#include <torch/csrc/autograd/FunctionsManual.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/autograd/functions/utils.h>


#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/BatchedTensorImpl.h>
#include <ATen/core/grad_mode.h>
#include <ATen/core/Reduction.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/ScalarOps.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/core/grad_mode.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <c10/util/SmallBuffer.h>


#include <ciso646>
#include <algorithm>
#include <numeric>
#include <functional>
// Helper functions for autogenerated code
// These used to be inlined into the codegened Functions.cpp

namespace torch {
namespace autograd {
namespace generated {
namespace details {

using at::Tensor;
using at::Scalar;
using at::IntArrayRef;
using at::TensorList;
using at::areAnyTensorSubclassLike;

// Message text about double backward through cuDNN RNNs, including a short
// Python example of disabling cuDNN for the forward pass. Its use site is not
// in this view (presumably it is raised as an error by the cuDNN RNN backward);
// the literal is runtime text, so keep it verbatim.
const char* kCudnnDoubleBackwardMsg = "Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API. To run double backwards, please disable the CuDNN backend temporarily while running the forward pass of your RNN. For example: \nwith torch.backends.cudnn.flags(enabled=False):\n    output = model(inputs)";

namespace {
  // Collapses an element-wise loss according to `reduction`:
  // Mean -> mean over all elements, Sum -> sum over all elements,
  // any other value -> the unreduced tensor is returned as-is.
  static inline at::Tensor apply_loss_reduction(const at::Tensor& unreduced, int64_t reduction) {
    switch (reduction) {
      case at::Reduction::Mean:
        return unreduced.mean();
      case at::Reduction::Sum:
        return unreduced.sum();
      default:
        return unreduced;
    }
  }
}

bool isDefined(const c10::optional<Tensor>& t) {
  // An empty optional and an optional holding an undefined Tensor are both
  // reported as "not defined".
  if (!t.has_value()) {
    return false;
  }
  return t->defined();
}

Tensor toNonOptTensor(const c10::optional<Tensor>& t) {
  // Unwraps the optional, substituting an undefined Tensor when it is empty.
  return t.value_or(Tensor());
}

Tensor toNonOptFwGrad(const c10::optional<Tensor>& t) {
  // Level-0 forward-mode gradient of the wrapped tensor, or an undefined
  // Tensor when there is no defined tensor to query.
  if (!t.has_value() || !t->defined()) {
    return Tensor();
  }
  return t->_fw_grad(/*level */ 0);
}

Tensor toNonOptPrimal(const c10::optional<Tensor>& t) {
  // Level-0 forward-mode primal of the wrapped tensor, or an undefined Tensor
  // when there is no defined tensor to query.
  if (!t.has_value() || !t->defined()) {
    return Tensor();
  }
  return t->_fw_primal(/*level */ 0);
}

void copy_range(variable_list& out, IndexRange range, const Tensor & t) {
  // Stores a single-Tensor output into its one-slot range of `out`.
  const auto first = range.first;
  const auto last = range.second;
  AT_ASSERT(last <= out.size());
  AT_ASSERTM(last - first == 1, "inconsistent range for Tensor output");
  out[first] = t;
}

void copy_range(variable_list& out, IndexRange range, at::ArrayRef<Tensor> t) {
  // Copies a list output into out[range.first, range.second); the range must
  // be exactly as wide as the list.
  const auto first = range.first;
  const auto last = range.second;
  AT_ASSERT(last <= out.size());
  AT_ASSERTM(last - first == t.size(), "inconsistent range for TensorList output");
  auto dest = out.begin() + first;
  for (const Tensor& item : t) {
    *dest++ = item;
  }
}

Tensor copysign_tensor_self_backward(const Tensor & grad, const Tensor & self, const Tensor & result) {
  // The local derivative w.r.t. self is result / self (the sign flip applied
  // by copysign). Where self == 0 that ratio is undefined, so 0 is used.
  Tensor sign_ratio = result / self;
  sign_ratio.masked_fill_(self == 0, 0);
  return grad * sign_ratio;
}

template <typename T>
T not_implemented_base(const char* name, const char* reason) {
  // Raises via TORCH_CHECK_NOT_IMPLEMENTED and never returns normally. The
  // template return type only lets callers use this in expression position for
  // whatever derivative type they need.
  std::string msg = c10::str("the derivative for '", name, "' is not implemented.");
  // `reason` is optional: append it only when it is non-null and non-empty
  // (calling strlen on a null pointer would be undefined behavior).
  if (reason != nullptr && reason[0] != '\0') {
    msg = c10::str(msg, " ", reason);
  }
  TORCH_CHECK_NOT_IMPLEMENTED(false, msg);
}

Tensor not_implemented(const char* name, const char* reason) {
  // Single-Tensor flavor of not_implemented_base; never returns normally.
  using Result = Tensor;
  return not_implemented_base<Result>(name, reason);
}

std::vector<Tensor> not_implemented_list(const char* name, const char* reason) {
  return not_implemented_base<std::vector<Tensor>>(name, reason);
}

Tensor maybe_multiply(const Tensor & t, const Scalar & s) {
  // Skip the multiplication when `s` is exactly one, which is checked only for
  // floating-point or integral (bool included) scalars. Any other scalar kind,
  // e.g. complex, is always multiplied.
  const bool scalar_is_one =
      s.isFloatingPoint() ? s.toDouble() == 1
                          : (s.isIntegral(true) && s.toLong() == 1);
  return scalar_is_one ? t : t * s;
}

int64_t _safe_size(IntArrayRef sizes, IntArrayRef dim) {
  // Product of the extents in `sizes` selected by the (possibly negative)
  // indices in `dim`. A 0-dim shape always yields 1.
  if (sizes.empty()) {
    return 1;
  }
  int64_t numel = 1;
  for (const int64_t raw_dim : dim) {
    numel *= sizes[at::maybe_wrap_dim(raw_dim, sizes.size())];
  }
  return numel;
}

Tensor handle_r_to_c(ScalarType self_st, Tensor gradient_result) {
  // R -> C: an input of real dtype `self_st` only receives the real part of a
  // complex gradient. Otherwise the gradient passes through unchanged.
  const bool real_input_complex_grad =
      gradient_result.is_complex() && !at::isComplexType(self_st);
  return real_input_complex_grad ? at::real(gradient_result) : gradient_result;
}

Tensor handle_r_to_c(Tensor self, Tensor gradient_result) {
  // R -> C: a real `self` only receives the real part of a complex gradient.
  // Otherwise the gradient passes through unchanged.
  const bool real_input_complex_grad =
      gradient_result.is_complex() && !self.is_complex();
  return real_input_complex_grad ? at::real(gradient_result) : gradient_result;
}

Tensor restore_reduced_dims(const Tensor &output, IntArrayRef dims, bool keepdim) {
  // Re-inserts size-1 axes at the positions in `dims` (given in the rank of the
  // un-reduced tensor, negatives allowed) that a keepdim=false reduction removed.
  if (keepdim) {
    return output;
  }
  const int64_t full_rank = output.dim() + dims.size();
  // Slot markers: 1 = reduced axis; 0 = still to be filled from output.sizes().
  std::vector<int64_t> shape(full_rank, 0);
  for (const int64_t d : dims) {
    shape[d < 0 ? d + full_rank : d] = 1;
  }
  // Fill the remaining slots, in order, with output's extents.
  auto slot = shape.begin();
  for (const int64_t extent : output.sizes()) {
    slot = std::find(slot, shape.end(), 0);
    *slot++ = extent;
  }
  return output.reshape(shape);
}

Tensor scale_grad_by_count(const Tensor &grad, const Tensor &mask, IntArrayRef dims) {
  // Divide grad by the number of selected mask entries along `dims` (kept as
  // size-1 axes), then zero out the unselected positions.
  const Tensor count = mask.sum(dims, true);
  return (grad / count) * mask;
}

std::tuple<Tensor, Tensor> _euclidean_dist_backward(const Tensor & grad, const Tensor & x1, const Tensor & x2, const Tensor & res) {
  if (!grad.defined()) {
    return std::tuple<Tensor, Tensor>(Tensor(), Tensor());
  }
  // handle case at 0 where we return a subgradient containing 0
  Tensor ratio = grad / res;
  ratio.masked_fill_(res == 0, 0);
  return std::tuple<Tensor, Tensor>{
            x1 * ratio.sum(-1, true) - ratio.matmul(x2),
            x2 * ratio.sum(-2, false).unsqueeze(-1) - ratio.mT().matmul(x1)};
}

Tensor norm_backward(const Tensor& grad, const Tensor& self, const optional<Scalar> & p_, const Tensor& norm) {
  // Overload without a dim list: forwards an empty dim list and keepdim=true
  // to the general overload.
  const IntArrayRef no_dims{};
  return norm_backward(grad, self, p_, norm, no_dims, /*keepdim=*/true);
}

// Gradient of the p-norm of `self` (reduced over `dim`) w.r.t. `self`.
//   grad          incoming gradient for the norm output
//   p_            norm order; 2 is used when absent
//   norm          the forward result
//   dim, keepdim  the forward's reduction arguments; when keepdim is false the
//                 reduced axes are re-inserted on grad/norm so they broadcast
//                 against self.
// For every p except 0 and 1 (which return early), positions where the norm is
// exactly 0 receive the subgradient 0.
Tensor norm_backward(Tensor grad, const Tensor& self, const optional<Scalar> & p_, Tensor norm, IntArrayRef dim, bool keepdim) {
  size_t ndim = self.sizes().size();
  double p = p_.value_or(2.0).toDouble();
  Tensor self_scaled;
  Tensor scale_v;

  // Restore reduced axes so grad/norm broadcast against self.
  if (!keepdim && self.dim() != 0) {
    grad = unsqueeze_multiple(grad, dim, ndim);
    norm = unsqueeze_multiple(norm, dim, ndim);
  }

  if (p == 0.0) {
    // The 0-"norm" counts non-zeros: piecewise constant, so zero gradient.
    return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  } else if (p == 1.0) {
    // d|x|/dx = sgn(x).
    return self.sgn() * grad;
  } else if (p == 2.0) {
    // d||x||_2/dx = x / ||x||_2.
    self_scaled = self;
    scale_v = grad / norm;
  } else if (std::isinf(p)) {
    // +/-inf norm: the gradient is split equally among the entries whose
    // absolute value equals the norm; NaN entries count as matching a NaN norm.
    const auto self_isnan = self.isnan();
    const auto norm_isnan = norm.isnan();
    // Out-of-place logical_and is used when any operand is a tensor subclass;
    // otherwise the fresh isnan mask is updated in place.
    const auto& self_and_norm_isnan = areAnyTensorSubclassLike({self, norm}) ?
      self_isnan.logical_and(norm_isnan) :
      self_isnan.logical_and_(norm_isnan);
    Tensor is_eq_max = (self.abs() == norm).logical_or_(self_and_norm_isnan).type_as(self);
    self_scaled = self.sgn() * is_eq_max;
    // Number of tied extremal entries per reduced slice.
    Tensor nb_max = is_eq_max.count_nonzero(dim);
    if (self.dim() != 0) {
      nb_max = unsqueeze_multiple(nb_max, dim, ndim);
    }
    scale_v = grad / nb_max;
  } else if (p < 2.0) {
    // General p: d||x||_p/dx = sgn(x) |x|^(p-1) / ||x||_p^(p-1).
    self_scaled = self.sgn() * self.abs().pow(p - 1);
    scale_v = grad / norm.pow(p - 1);
  } else {
    // Same formula for p > 2, written with x |x|^(p-2) in place of sgn(x) |x|^(p-1).
    self_scaled = self * self.abs().pow(p - 2);
    scale_v = grad / norm.pow(p - 1);
  }
  // handle case at 0 where we return a subgradient containing 0
  scale_v.masked_fill_(norm == 0, 0);
  return self_scaled * scale_v;
}

Tensor linalg_vector_norm_backward(Tensor grad, const Tensor& self, const Scalar& scalar_ord, Tensor norm, const optional<IntArrayRef>& opt_dim, bool keepdim) {
  // An absent `opt_dim` is forwarded to norm_backward as an empty dim list.
  IntArrayRef dim;
  if (opt_dim.has_value()) {
    dim = *opt_dim;
  }
  return norm_backward(grad, self, scalar_ord, norm, dim, keepdim);
}

Tensor pow_backward(Tensor grad, const Tensor & self, const Scalar & exponent) {
  // d/dself self^e = e * self^(e-1); conjugated as required for complex autograd.
  if (exponent.equal(0.0)) {
    // self^0 is constant in self.
    return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }
  auto scaled_grad = [&](auto e) {
    return grad * (e * self.pow(e - 1)).conj();
  };
  Tensor out;
  if (exponent.isComplex()) {
    out = scaled_grad(exponent.toComplexDouble());
  } else {
    out = scaled_grad(exponent.toDouble());
  }
  return handle_r_to_c(self, out);
}

Tensor pow_backward_self(Tensor grad, const Tensor & self, const Tensor & exponent) {
  // Element-wise exponent: zero gradient where the exponent is 0, and
  // grad * conj(e * self^(e-1)) elsewhere.
  const auto zero = at::zeros({}, grad.options());
  const auto local_grad = grad * (exponent * self.pow(exponent - 1)).conj();
  auto out = at::where(exponent == 0.0, zero, local_grad);
  return handle_r_to_c(self, out);
}

// Gradient of self^exponent w.r.t. the tensor exponent:
//   grad * conj(result * log(self)), where result = self^exponent.
//
// Caveats:
// We define d(a^b)/db at a = 0 and b < 0 to be -inf. This is due to
// d(a^b)/db -> -inf for a fixed b as a -> +0
// Currently, tensorflow defines d(a^b)/db = nan for a = 0 and b < 0.
//
// We define d(a^b)/db = 0 for a = 0 and b = 0 by continuity as
// d(a^b)/db = 0 for a > 0 and b -> +0.
// Currently, tensorflow agrees with us.
Tensor pow_backward_exponent(Tensor grad, const Tensor& self, const Tensor& exponent, Tensor result) {
  // At self == 0 with a non-negative real exponent, the raw formula yields
  // 0 * -inf = NaN (b > 0) or 1 * -inf (b == 0); 0 is used there instead.
  const auto self_is_zero = self == 0;
  Tensor exponent_nonneg_real;
  if (exponent.is_complex()) {
    exponent_nonneg_real = at::logical_and(at::imag(exponent) == 0, at::real(exponent) >= 0);
  } else {
    exponent_nonneg_real = exponent >= 0;
  }
  const auto use_zero = at::logical_and(self_is_zero, exponent_nonneg_real);
  const auto zero = at::zeros({}, grad.options());
  auto out = grad * at::where(use_zero, zero, (result * self.log()).conj());
  return handle_r_to_c(exponent, out);
}

Tensor pow_backward_exponent(Tensor grad, const Scalar & base, const Tensor& exponent, Tensor result) {
  // Gradient of base^exponent w.r.t. the exponent for a scalar base:
  //   grad * conj(result * log(base)).
  auto log_term = [](Tensor value, Scalar b) { return (value * b.log()).conj(); };
  Tensor out;
  if (!base.equal(0.0)) {
    out = grad * log_term(result, base);
  } else {
    // base == 0: zero gradient wherever the exponent is a non-negative real
    // number (the formula would otherwise produce NaN / -inf there).
    Tensor nonneg_real;
    if (exponent.is_complex()) {
      nonneg_real = at::logical_and(at::imag(exponent) == 0, at::real(exponent) >= 0);
    } else {
      nonneg_real = exponent >= 0;
    }
    out = grad * at::where(nonneg_real,
                           at::zeros({}, grad.options()),
                           log_term(result, base));
  }
  return handle_r_to_c(exponent, out);
}

Tensor angle_backward(Tensor grad, const Tensor& self) {
  // Real inputs get a zero gradient.
  if (!self.is_complex()) {
    return at::zeros_like(self, at::MemoryFormat::Preserve);
  }
  const Scalar imag_unit(c10::complex<double>{0.0, 1.0});
  // Complex: grad * self / |self|^2 * i, with the subgradient 0 at the origin.
  return at::where(self == 0.0, at::zeros({}, self.options()),
                   grad * self / self.abs().pow(2) * imag_unit);
}

Tensor mvlgamma_backward(Tensor grad, const Tensor & self, int64_t p) {
  // d/dx mvlgamma(x, p) = sum over the p offsets of digamma(x + offset), where
  // the offsets are arange(-p/2 + 1/2, 1/2, step 1/2) = -(p-1)/2, ..., -1/2, 0.
  const double start = -p / 2. + 0.5;
  Tensor shifted = at::arange(start, 0.5, 0.5, self.options()).add(self.unsqueeze(-1));
  return grad * shifted.digamma_().sum(-1);
}

Tensor sgn_backward(Tensor result, Tensor grad, Tensor self) {
  // Real inputs get a zero gradient.
  if (!self.is_complex()) {
    return at::zeros_like(self, at::MemoryFormat::Preserve);
  }
  // C -> C
  // https://arxiv.org/pdf/1701.00392.pdf Section 4.20
  // Subgradient 0 where |self| == 0.
  const auto magnitude = at::abs(self);
  const auto formula = grad / magnitude - at::real(grad / self) * result;
  return at::where(magnitude == 0.0, at::zeros({}, grad.options()), formula);
}

Tensor mul_tensor_backward(Tensor grad, Tensor other, ScalarType self_st) {
  // d(self * other)/dself = conj(other); real self keeps only the real part.
  return handle_r_to_c(self_st, grad * other.conj());
}

Tensor div_tensor_self_backward(Tensor grad, Tensor other, ScalarType self_st, const c10::optional<c10::string_view>& rounding_mode) {
  if (!rounding_mode.has_value()) {
    // True division: d(self / other)/dself = 1 / conj(other).
    return handle_r_to_c(self_st, grad / other.conj());
  }
  // With any rounding mode, a zero gradient of self's dtype is returned.
  return at::zeros_like(grad, grad.options().dtype(self_st));
}

Tensor div_tensor_self_backward(Tensor grad, Tensor other, ScalarType self_st) {
  // True division: delegate with no rounding mode.
  const c10::optional<c10::string_view> no_rounding = c10::nullopt;
  return div_tensor_self_backward(grad, other, self_st, no_rounding);
}

Tensor div_tensor_other_backward(Tensor grad, Tensor self, Tensor other, const c10::optional<c10::string_view>& rounding_mode) {
  if (!rounding_mode.has_value()) {
    // True division: d(self / other)/dother = -self / other^2 (conjugated).
    const auto quotient = self / other;
    return handle_r_to_c(other, -grad * (quotient / other).conj());
  }
  // With any rounding mode, a zero gradient of other's dtype is returned.
  return at::zeros_like(grad, grad.options().dtype(other.scalar_type()));
}

Tensor div_tensor_other_backward(Tensor grad, Tensor self, Tensor other) {
  // True division: delegate with no rounding mode.
  const c10::optional<c10::string_view> no_rounding = c10::nullopt;
  return div_tensor_other_backward(grad, self, other, no_rounding);
}

Tensor permute_backwards(const Tensor & grad, IntArrayRef fwd_dims) {
  // Output axis i of the forward permute came from input axis fwd_dims[i], so
  // the backward permutes grad by the inverse permutation.
  const auto rank = fwd_dims.size();
  std::vector<int64_t> inverse(rank);
  for (size_t out_axis = 0; out_axis < rank; ++out_axis) {
    const auto in_axis = at::maybe_wrap_dim(fwd_dims[out_axis], rank);
    inverse[in_axis] = static_cast<int64_t>(out_axis);
  }
  return grad.permute(inverse);
}

Tensor rad2deg_backward(const Tensor& grad) {
  // rad2deg scales by 180/pi, so the gradient is scaled by the same constant.
  constexpr double kDegreesPerRadian = 57.295779513082320876798154814105170332405472466564;
  const Tensor factor = at::native::wrapped_scalar_tensor(Scalar(kDegreesPerRadian));
  return at::mul(grad, factor);
}

Tensor deg2rad_backward(const Tensor& grad) {
  // deg2rad scales by pi/180, so the gradient is scaled by the same constant.
  constexpr double kRadiansPerDegree = 0.017453292519943295769236907684886127134428718885417;
  const Tensor factor = at::native::wrapped_scalar_tensor(Scalar(kRadiansPerDegree));
  return at::mul(grad, factor);
}

Tensor unsqueeze_multiple(const Tensor & t, IntArrayRef dim, size_t n_dims) {
  // Inserts a size-1 axis at each position in `dim`, where positions refer to
  // the final rank `n_dims`. Positions are visited in ascending order, so
  // earlier insertions never shift later ones.
  const auto is_target = at::dim_list_to_bitset(dim, n_dims);
  Tensor expanded = t;
  for (size_t pos = 0; pos < n_dims; ++pos) {
    if (is_target[pos]) {
      expanded = expanded.unsqueeze(pos);
    }
  }
  return expanded;
}

Tensor sum_backward(const Tensor & grad, IntArrayRef sizes, IntArrayRef dims, bool keepdim) {
  // Broadcast the reduced gradient back to the input shape `sizes`. Reduced
  // axes are re-inserted first when keepdim was false (0-dim inputs need none).
  if (keepdim || sizes.empty()) {
    return grad.expand(sizes);
  }
  const Tensor restored = dims.size() == 1
      ? grad.unsqueeze(dims[0])
      : unsqueeze_multiple(grad, dims, sizes.size());
  return restored.expand(sizes);
}

Tensor nansum_backward(const Tensor & grad, const Tensor & self, IntArrayRef dims, bool keepdim) {
  // Same broadcast as sum_backward, then zero the positions where `self` was
  // NaN.
  const Tensor not_nan = self.isnan().logical_not();
  return sum_backward(grad, self.sizes(), dims, keepdim) * not_nan;
}

std::vector<int64_t> reverse_list(const IntArrayRef list) {
  // Copy of `list` with its elements in reverse order.
  return std::vector<int64_t>(list.rbegin(), list.rend());
}

Tensor reverse_dim(const Tensor& t, int64_t dim) {
  // Reverse `t` along `dim` by gathering indices n-1, n-2, ..., 0.
  const int64_t n = t.size(dim);
  const auto long_opts = t.options().dtype(at::kLong);
  return t.index_select(dim, at::arange(n - 1, -1, -1, long_opts));
}

Tensor prod_safe_zeros_backward(const Tensor &grad, const Tensor& inp, int64_t dim) {
  // Gradient of prod along `dim` that never divides by an input: each position
  // receives (product of everything before it) * (product of everything after it).
  const int64_t len = inp.size(dim);
  if (len == 1) {
    return grad;
  }

  auto ones_shape = inp.sizes().vec();
  ones_shape[dim] = 1;
  const Tensor ones = at::ones(ones_shape, grad.options());

  // Exclusive cumprod from the front: [1, a, a*b, ...].
  const Tensor before = at::cat({ones, inp.narrow(dim, 0, len - 1)}, dim).cumprod(dim);

  // Exclusive cumprod from the back: [..., c*d, d, 1].
  const Tensor tail_reversed = reverse_dim(inp.narrow(dim, 1, len - 1), dim);
  const Tensor after = reverse_dim(at::cat({ones, tail_reversed}, dim).cumprod(dim), dim);

  return grad * (before * after).conj();
}

// Gradient of a full prod(): each element receives the product of all *other*
// elements, i.e. cumprod(exclusive, normal) * cumprod(exclusive, reverse):
//   input:                        [    a,     b,     c]
//   cumprod(exclusive, normal):   [1    ,     a, a * b]
//   cumprod(exclusive, reverse):  [b * c,     c,     1]
//   product:                      [b * c, a * c, a * b]
// That formulation stays valid when the input contains zeros.
Tensor prod_backward(const Tensor& grad, const Tensor& input, const Tensor& result) {
  if (input.dim() == 0) {
    return grad;
  }
  const Tensor zero_idx = (input == 0).nonzero();
  if (zero_idx.numel() == 0) {
    // No zeros: the product of the others is simply result / x.
    return grad * (result / input).conj();
  }
  if (zero_idx.size(0) > 1) {
    // Two or more zeros: every product of the others contains a zero.
    return at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }
  // Exactly one zero: flatten and use the division-free formulation.
  return prod_safe_zeros_backward(grad, input.contiguous().view(-1), 0).view_as(input);
}

Tensor prod_backward(Tensor grad, const Tensor& input, Tensor result, int64_t dim, bool keepdim) {
  // Gradient of prod(input, dim, keepdim) w.r.t. input.
  if (input.dim() == 0) {
    return grad;
  }
  dim = at::maybe_wrap_dim(dim, input.sizes().size());
  if (!keepdim && input.dim() != 1) {
    // Restore the reduced axis so grad/result broadcast against input.
    grad = grad.unsqueeze(dim);
    result = result.unsqueeze(dim);
  }

  // Only whether any zero exists matters here, so one global reduction is
  // enough. There is no need for a per-slice count that is then summed again.
  const int64_t total_zeros = (input == 0).sum().item<int64_t>();
  if (total_zeros == 0) {
    // No zeros: the product of the others is simply result / x.
    return grad * (result / input).conj();
  }
  // Some slice contains a zero: use the division-free formulation.
  return prod_safe_zeros_backward(grad, input, dim);
}

template <typename solve_f>
static Tensor generic_solve_jvp(
  solve_f solve,
  const Tensor& X, const Tensor& A,
  const Tensor& dA, const Tensor& dB) {
  // Differentiating A X = B gives A dX = dB - dA X. `solve` receives A, dB and
  // the dA X term and decides how to combine them. In general,
  // dX = solve(A, dB - dA X), but lu_solve combines them differently; see
  // lu_solve_jvp for details.
  const bool rhs_is_vector = at::native::linalg_solve_is_vector_rhs(dA, dB);
  Tensor dA_X;
  if (rhs_is_vector) {
    // Treat the vector X as a single column for the product, then drop that axis.
    dA_X = dA.matmul(X.unsqueeze(-1)).squeeze(-1);
  } else {
    dA_X = dA.matmul(X);
  }
  return solve(A, dB, dA_X);
}

Tensor solve_jvp(
  const Tensor& X,
  const Tensor& A,
  const Tensor& dA,
  const Tensor& dB
) {
  // Plain linalg.solve: dX = solve(A, dB - dA X).
  auto solve_residual = [](const Tensor& lhs, const Tensor& rhs_tangent, const Tensor& correction) {
    return at::linalg_solve(lhs, rhs_tangent - correction);
  };
  return generic_solve_jvp(solve_residual, X, A, dA, dB);
}

Tensor solve_backward_self(const Tensor & grad, const Tensor & self, const Tensor & A) {
  // X = A^{-1} B  =>  dL/dB = A^{-H} dL/dX. `self` (B) is not needed.
  const Tensor A_adjoint = A.mH();
  return at::linalg_solve(A_adjoint, grad);
}

Tensor solve_backward_A(const Tensor & grad, const Tensor & self, const Tensor & A, const Tensor & solution) {
  // dL/dA = -(dL/dB) X^H, with dL/dB obtained from solve_backward_self.
  const Tensor grad_B = solve_backward_self(grad, self, A);
  const bool plain_matrices = self.ndimension() == 2 && A.ndimension() == 2;
  if (plain_matrices) {
    return -at::mm(grad_B, solution.mH());
  }
  // A vector right-hand side (..., M) is viewed as (..., M, 1) for the outer product.
  if (at::native::linalg_solve_is_vector_rhs(A, self)) {
    return -at::matmul(grad_B.unsqueeze(-1), solution.unsqueeze(-1).mH());
  }
  return -at::matmul(grad_B, solution.mH());
}

Tensor cumsum_backward(const Tensor & grad, int64_t dim) {
  // Each input position i receives the sum of grad over positions >= i along
  // `dim`, i.e. a cumsum of the reversed gradient, reversed back.
  const bool trivial = grad.numel() <= 1 || grad.size(dim) == 1;
  if (trivial) {
    return grad;
  }
  const Tensor reversed = grad.flip(dim);
  return reversed.cumsum(dim).flip(dim);
}

Tensor logsumexp_backward(Tensor grad, const Tensor & self, Tensor result, IntArrayRef dim, bool keepdim) {
  // d logsumexp(x)/dx_i = exp(x_i - logsumexp(x)).
  const bool needs_unsqueeze = !keepdim && self.dim() != 0;
  if (needs_unsqueeze) {
    const auto full_rank = self.sizes().size();
    grad = unsqueeze_multiple(grad, dim, full_rank);
    result = unsqueeze_multiple(result, dim, full_rank);
  }
  const Tensor weights = (self - result).exp();
  return grad * weights;
}

Tensor logcumsumexp_backward(Tensor grad, const Tensor & self, Tensor result, int64_t dim) {
  // Nothing to propagate for 0-dim or empty gradients.
  if (grad.dim() == 0 || grad.numel() == 0) {
    return grad;
  }

  // Reference: https://github.com/tensorflow/tensorflow/blob/
  // 2a5910906a0e0f3dbc186ff9db6386d81a63448c/tensorflow/python/ops/math_grad.py#L1832-L1863
  // The positive and negative parts of grad are handled separately in log space.
  return AT_DISPATCH_FLOATING_TYPES(
      at::typeMetaToScalarType(grad.dtype()),
      "logcumsumexp_backward",
      [grad, self, result, dim]() {
        // Stand-in for log(0): the most negative finite value of scalar_t.
        auto log_zero = at::empty_like(grad);
        log_zero.fill_(std::numeric_limits<scalar_t>::lowest());
        auto log_pos = at::where(grad > 0, grad.log(), log_zero);
        auto log_neg = at::where(grad < 0, (-grad).log(), log_zero);

        // logcumsumexp accumulated from the end of `dim` towards its start.
        auto suffix_logcumsumexp = [dim](auto x) {
          return at::flip(at::logcumsumexp(at::flip(x, {dim}), dim), {dim});
        };

        auto pos_part = (suffix_logcumsumexp(log_pos - result) + self).exp();
        auto neg_part = (suffix_logcumsumexp(log_neg - result) + self).exp();
        return pos_part - neg_part;
      });
}

Tensor unbind_backward(const variable_list& grads, int64_t dim) {
  // Stack the per-slice gradients back along `dim`. Undefined entries become
  // zeros with the shape and options of the first defined gradient. If every
  // entry is undefined, scalar zeros with default options are used.
  IntArrayRef sizes;
  at::TensorOptions o;
  const auto first_defined = std::find_if(grads.begin(), grads.end(),
      [](const Variable& v) { return v.defined(); });
  if (first_defined != grads.end()) {
    sizes = first_defined->sizes();
    o = static_cast<Tensor>(*first_defined).options();
  }
  auto grads_tensors = fmap(grads, [&](const Variable& v) {
    if (v.defined()) {
      return static_cast<Tensor>(v);
    }
    return at::zeros({}, o).expand(sizes);
  });
  return at::stack(grads_tensors, dim);
}

Tensor unsqueeze_to(const Tensor & self, IntArrayRef sizes) {
  // Inserts a size-1 axis at every position where the target shape `sizes`
  // has extent 1, visiting positions in ascending order.
  Tensor out = self;
  const int64_t rank = sizes.size();
  for (int64_t d = 0; d < rank; ++d) {
    if (sizes[d] == 1) {
      out = out.unsqueeze(d);
    }
  }
  return out;
}

Tensor unsqueeze_to(const Tensor & self, int64_t dim, IntArrayRef sizes) {
  // Re-inserts axis `dim` only if the original shape `sizes` had extent 1 there.
  // In NumPy it is not an error to unsqueeze a scalar, but for a 0-dim
  // original shape we still need to avoid unsqueezing in the backward.
  dim = at::maybe_wrap_dim(dim, sizes.size());
  const bool restore_axis = !sizes.empty() && sizes[dim] == 1;
  return restore_axis ? self.unsqueeze(dim) : self;
}

std::vector<Tensor> cat_tensors_backward(const Tensor & grad, const std::vector<std::vector<int64_t>> &sizes, const std::vector<ScalarType> &dtypes, int64_t dim) {
  std::vector<Tensor> grad_inputs(sizes.size());
  if (!grad.defined()) {
    return grad_inputs;
  }
  dim = at::legacy_cat_wrap_dim(dim, sizes);
  int64_t accumulate = 0;

  Tensor grad_;
  bool grad_is_complex = grad.is_complex();
  if (grad_is_complex) {
    grad_ = at::real(grad);
  }
  for (const auto i : c10::irange(sizes.size())) {
    Tensor grad_val;
    if (!at::isComplexType(dtypes[i]) && grad_is_complex) {
      // R -> C
      grad_val = grad_;
    } else {
      grad_val = grad;
    }
    auto& shape = sizes[i];
    // If input was empty tensor, gradInput should be empty tensor.
    if (shape == std::vector<int64_t>({0})) {
      grad_inputs[i] = at::zeros({0}, grad_val.options());
      continue;
    }
    auto size = shape[dim];
    accumulate += size;
    grad_inputs[i] = grad_val.narrow(dim, accumulate - size, size);
  }
  return grad_inputs;
}

Tensor clamp_backward(const Tensor & grad, const Tensor &self, const optional<Scalar> & min, const optional<Scalar> & max) {
  // clamp: gradients not defined on min and max, so we return the subgradient 1 for these cases.
  // grad passes through where self lies within the given bounds (inclusive)
  // and is zeroed elsewhere.
  if (!min && !max) {
    return grad;
  }
  auto zero = at::scalar_tensor(0., grad.options());
  if (!max) {
    return where(self >= *min, grad, zero);
  }
  if (!min) {
    return where(self <= *max, grad, zero);
  }
  return where((self >= *min).logical_and_(self <= *max), grad, zero);
}

Tensor clamp_backward(const Tensor & grad, const Tensor &self, const Tensor& min, const Tensor& max) {
  // clamp: gradients not defined on min and max, so we return the subgradient 1 for these cases.
  // Undefined bound tensors are treated as absent bounds.
  const bool has_min = min.defined();
  const bool has_max = max.defined();
  if (!has_min && !has_max) {
    return grad;
  }
  auto zero = at::scalar_tensor(0., grad.options());
  if (!has_max) {
    return where(self >= min, grad, zero);
  }
  if (!has_min) {
    return where(self <= max, grad, zero);
  }
  const auto above_min = self >= min;
  const auto below_max = self <= max;
  // The out-of-place logical_and is used when any operand is a tensor
  // subclass; otherwise the fresh comparison mask is updated in place.
  const auto& in_range = areAnyTensorSubclassLike({self, min, max})
      ? above_min.logical_and(below_max)
      : above_min.logical_and_(below_max);
  return where(in_range, grad, zero);
}

std::tuple<at::Tensor, at::Tensor> clamp_backward_min_max(
    const Tensor& grad, const Tensor& self, const Tensor& min, const Tensor& max,
    const std::array<bool, 2>& grad_input_mask) {
  // If min > max, min has no gradient
  std::tuple<at::Tensor, at::Tensor> ret;
  if (!grad.defined()) {
    return ret;
  }

  auto zero = at::scalar_tensor(0., grad.options());
  if (max.defined() && min.defined()) {
    if (grad_input_mask[0]) {
      const auto self_lt_min = self < min;
      const auto min_lt_max = min < max;
      const auto& pred = areAnyTensorSubclassLike({self, min, max}) ?
        self_lt_min.logical_and(min_lt_max) :
        self_lt_min.logical_and_(min_lt_max);
      std::get<0>(ret) = where(pred, grad, zero);
    }
    if (grad_input_mask[1]) {
      const auto self_gt_max = self > max;
      const auto max_lt_min = max < min;
      const auto& pred = areAnyTensorSubclassLike({self, min, max}) ?
        self_gt_max.logical_or(max_lt_min) :
        self_gt_max.logical_or_(max_lt_min);
      std::get<1>(ret) = where(pred, grad, zero);
    }
  } else if (min.defined() && grad_input_mask[0]) {
    std::get<0>(ret) = where(self < min, grad, zero);
  } else if (max.defined() && grad_input_mask[1]) {
    std::get<1>(ret) = where(self > max, grad, zero);
  }
  return ret;
}

Tensor convolution_jvp(
    const Tensor& input_p, const Tensor& input_t,
    const Tensor& weight_p, const Tensor& weight_t,
    const Tensor& bias_p, const Tensor& bias_t,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation,
    bool transposed, IntArrayRef output_padding, int64_t groups) {
  // Tangent of convolution(x, w, b), which is linear in x and in (w, b):
  //   conv(dx, w) + conv(x, dw, db).  The bias primal does not appear.
  const c10::optional<at::Tensor> bias_tangent =
      bias_t.defined() ? c10::optional<at::Tensor>(bias_t) : c10::nullopt;
  Tensor from_input = at::convolution(
      input_t, weight_p, c10::nullopt, stride, padding, dilation, transposed, output_padding, groups);
  Tensor from_weight_and_bias = at::convolution(
      input_p, weight_t, bias_tangent, stride, padding, dilation, transposed, output_padding, groups);
  return from_input + from_weight_and_bias;
}

Tensor _convolution_jvp(
    const Tensor& input_p, const Tensor& input_t,
    const Tensor& weight_p, const Tensor& weight_t,
    const Tensor& bias_p, const Tensor& bias_t,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation,
    bool transposed, IntArrayRef output_padding, int64_t groups,
    bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) {
  // Same decomposition as convolution_jvp, routed through at::_convolution so
  // the backend-selection flags are forwarded unchanged:
  //   _conv(dx, w) + _conv(x, dw, db).  The bias primal does not appear.
  const c10::optional<at::Tensor> bias_tangent =
      bias_t.defined() ? c10::optional<at::Tensor>(bias_t) : c10::nullopt;
  Tensor from_input = at::_convolution(
      input_t, weight_p, c10::nullopt, stride, padding, dilation, transposed, output_padding,
      groups, benchmark, deterministic, cudnn_enabled, allow_tf32);
  Tensor from_weight_and_bias = at::_convolution(
      input_p, weight_t, bias_tangent, stride, padding, dilation, transposed, output_padding,
      groups, benchmark, deterministic, cudnn_enabled, allow_tf32);
  return from_input + from_weight_and_bias;
}

Tensor convolution_backward_jvp_grad_bias(
    const Tensor& grad_out_t,
    const Tensor& grad_bias) {
  // Tangent of the convolution bias gradient. This sums grad_out_t over the
  // batch axis (0) and every spatial axis (2, 3, ...). Returns an undefined
  // Tensor when no bias gradient was computed.
  if (!grad_bias.defined()) {
    return Tensor();
  }
  const int64_t spatial_dims = grad_out_t.dim() - 2;
  // Checked up front so every path after it returns a value.
  TORCH_INTERNAL_ASSERT(
      spatial_dims >= 1 && spatial_dims <= 3,
      "convolution_backward_jvp_grad_bias expected dim of grad_out_t to be 3, 4, or 5, but got: ",
      grad_out_t.dim());
  // A std::vector is used rather than an initializer list to avoid overload
  // ambiguity in Tensor::sum.
  std::vector<int64_t> reduce_dims{0};
  for (int64_t d = 2; d < grad_out_t.dim(); ++d) {
    reduce_dims.push_back(d);
  }
  return grad_out_t.sum(reduce_dims);
}

// load_derivatives.py substitutes calls to this function for `tensor.strides()`
// in derivative formulas.
//
//  - If `input` does not have requires_grad set, there will be no backward
//    pass for it, so an empty array is returned.
//  - If it requires grad and is sparse, an error is raised, because the
//    backward needs a strided tensor.
//  - If it requires grad and is an MKLDNN tensor, an empty array is returned:
//    MKLDNN tensors are opaque and carry no stride information.
//  - Otherwise its strides are returned.
//
// Only the case where `input` is the tensor whose single derivative is being
// calculated is supported. `self` derivatives of in-place functions are not
// supported.
//
// Args:
//  input              Tensor to call .strides() on
//  input_name         Name of `input` tensor, from derivative formula
at::IntArrayRef strides_or_error(const Tensor & input, c10::string_view const & input_name) {
  // TODO: Ideally, this function would never be called if requires_grad is
  // not set. Once codegen is updated to avoid the call, we can remove this
  // check.
  if (!input.requires_grad()) {
    return IntArrayRef({});
  }
  TORCH_CHECK(
    !input.is_sparse(),
    "The backward pass for this operation requires the '", input_name,
    "' tensor to be strided, but a sparse tensor was given instead. ",
    "Please either use a strided tensor or set requires_grad=False for '",
    input_name, "'");
  return input.is_mkldnn() ? IntArrayRef({}) : input.strides();
}

Tensor mm_mat1_backward(const Tensor & grad, const Tensor & mat2, at::IntArrayRef mat1_sizes, at::IntArrayRef mat1_strides, const Scalar & alpha) {
  // dL/dmat1 = conj(alpha) * grad @ mat2^H.
  const bool mat1_col_major = mat1_strides[0] == 1 && mat1_strides[1] == mat1_sizes[0];
  // For a column-major mat1, the product is built as the transpose of
  // conj(mat2) @ grad^T so the gradient also comes out column-major.
  const Tensor product = mat1_col_major
      ? mat2.conj().mm(grad.t()).t()
      : grad.mm(mat2.t().conj());
  return maybe_multiply(product, alpha.conj());
}

Tensor mm_mat2_backward(const Tensor & grad, const Tensor & mat1, IntArrayRef sizes, IntArrayRef strides, const Scalar & alpha) {
  // dL/dmat2 = conj(alpha) * mat1^H @ grad. `sizes`/`strides` describe mat2;
  // when mat2 was column-major the gradient is produced column-major as well.
  if (strides[0] == 1 && strides[1] == sizes[0]) {
    if (mat1.is_sparse()) {
      // Since mm(dense, sparse) doesn't exist,
      // pass a transposed output matrix to the underlying "addmm"
      // function directly.
      int64_t out_rows = mat1.size(1);
      int64_t out_cols = grad.size(1);
      Tensor t = at::zeros({}, grad.options()).expand({out_rows, out_cols}, true);
      Tensor r = at::empty({out_cols, out_rows}, grad.options()).t();
      // addmm_out(out, self, mat1, mat2, beta, alpha) computes
      // beta * self + alpha * (mat1 @ mat2). `t` is all zeros, so only the
      // product matters, and it must carry the alpha scaling (conjugated, like
      // the dense branches below).
      // NOTE(review): unlike the dense branches, mat1 is not conjugated here;
      // this only matters for complex sparse inputs - confirm sparse conj support.
      at::addmm_out(r, t, mat1.t(), grad, /*beta=*/1, /*alpha=*/alpha.conj());
      return r;
    }
    return maybe_multiply(grad.t().mm(mat1.conj()).t(), alpha.conj());
  }
  return maybe_multiply(mat1.t().conj().mm(grad), alpha.conj());
}

Tensor _sparse_addmm_sparse_backward(const Tensor& grad, const Tensor& sparse_, const Tensor& dense, const Scalar& alpha) {
  // Gradient w.r.t. the sparse operand S of alpha * (S @ D):
  //   conj(alpha) * grad @ D^H, restricted to S's sparsity pattern.
  // alpha is conjugated to match mm_mat1_backward / mm_mat2_backward; for a
  // real alpha this is a no-op.
  AT_ASSERT(sparse_.is_sparse());
  auto sparse = sparse_.coalesce();
  Tensor grad_sparse = maybe_multiply(grad.mm(dense.conj().t()), alpha.conj());
  return grad_sparse.sparse_mask(sparse);
}

// Returns a new sparse tensor whose indices are a copy of `mask`'s indices and
// whose values are taken from `input` at those indices; `mask`'s own values are
// ignored. `input` and `mask` are expected to be sparse matrices of the same
// shape.
// Cloning the indices is enough for the index part. The values must be
// gathered from `input` using `mask`'s indices, which is what the CPU/CUDA
// helper `_sparse_mask_helper` does.
Tensor _sparse_matrix_mask(const Tensor& input, const Tensor& mask){
  Tensor mask_indices = mask._indices().clone();
  // With no stored entries there is nothing to gather; use zeros shaped like
  // mask's (empty) values instead.
  Tensor r_values = mask._nnz() == 0
      ? at::zeros_like(mask._values())
      : _sparse_mask_helper(input, mask_indices.contiguous());
  Tensor output = at::empty_like(mask);
  at::sparse::get_sparse_impl(output)->set_indices_and_values_unsafe(mask_indices, r_values);
  return output;
}

Tensor sparse_sparse_matmul_backward(
    const Tensor& grad,
    const Tensor& a,
    const Tensor& b,
    int64_t grad_order) {
  /*
  Backward of the sparse-sparse matrix product c = a @ b.

  For dense operands the gradients are
      a_grad = c_grad @ b^H
      b_grad = a^H @ c_grad

  For sparse operands, each product is restricted to the sparsity pattern of
  the input it is the gradient for:
      grad_order == 0:  a_grad = sparse_matrix_mask(c_grad @ b^H, mask=a)
      grad_order == 1:  b_grad = sparse_matrix_mask(a^H @ c_grad, mask=b)
  */
  TORCH_CHECK(
      grad_order == 0 || grad_order == 1,
      ": grad_order not in [0, 1] at sparse_sparse_matmul_backward function");
  const bool wrt_a = grad_order == 0;
  const Tensor& pattern = wrt_a ? a : b;
  Tensor unmasked = wrt_a ? _sparse_sparse_matmul(grad, b.conj().t())
                          : _sparse_sparse_matmul(a.conj().t(), grad);
  return _sparse_matrix_mask(unmasked.coalesce(), pattern.coalesce());
}

// Backward of renorm.
// Each slice of `self` along `dim` has its p-norm taken over all remaining
// dims. Where that norm exceeds `maxnorm`, the returned gradient is the
// gradient of maxnorm * self / (norm + 1e-7); elsewhere `grad` passes through
// unchanged (see the final at::where).
Tensor renorm_backward(const Tensor & grad, const Tensor & self, const Scalar& p_s, int64_t dim, const Scalar& maxnorm) {
  auto self_sizes = self.sizes();
  dim = c10::maybe_wrap_dim(dim, self_sizes.size());
  // Reduce over every dimension except `dim`.
  at::DimVector reduce_dims(self_sizes.size());
  std::iota(reduce_dims.begin(), reduce_dims.end(), 0);
  reduce_dims.erase(reduce_dims.begin() + dim);
  auto dtype = self.scalar_type();
  // NOTE(review): is_cuda=true is passed regardless of self's device.
  // Presumably this makes reduced-precision dtypes accumulate in a wider type
  // on every device; confirm.
  auto acc_type = at::toAccumulateType(dtype, /*is_cuda=*/true);
  const auto p = p_s.toDouble();

  // Per-slice norms, kept as size-1 dims so they broadcast against `self`.
  // They are computed in the accumulation dtype when it differs from self's.
  Tensor norm;
  if (acc_type != dtype) {
    norm = at::linalg_vector_norm(
        self, p, reduce_dims, /*keepdim=*/true, /*dtype=*/acc_type);
  } else {
    norm = at::linalg_vector_norm(
        self, p, reduce_dims, /*keepdim=*/true);
  }

  const auto real_acc_type = c10::toValueType(acc_type);
  // Per-slice sum of conj(self) * grad. This is the incoming gradient for the
  // norm term, consumed by the norm backward below.
  auto grad_output = (self.conj() * grad);
  // vector_norm output is real, so grad_output must also be real
  if (real_acc_type != acc_type) {
    grad_output = at::real(grad_output);
  }
  grad_output = grad_output.sum(
      reduce_dims, /*keepdim=*/true, /*dtype=*/real_acc_type);
  // nb: gradient of the per-slice norm w.r.t. self, weighted by grad_output.
  auto nb = linalg_vector_norm_backward(
      grad_output, self, p, norm, reduce_dims, /*keepdim=*/true);

  // Quotient rule for maxnorm * self / (norm + eps), with
  // invnorm = 1 / (norm + eps):
  //   grad_in = maxnorm * invnorm * (grad - invnorm * nb)
  auto invnorm = (norm + 1e-7).reciprocal();
  auto grad_norm = maxnorm * invnorm * (grad - invnorm * nb);
  // Only slices whose norm exceeds maxnorm use the rescaled gradient.
  return at::where(norm > maxnorm, grad_norm.to(grad.scalar_type()), grad);
}

// Backward of Tensor::repeat: sums `grad` (the gradient of the repeated
// tensor) back down to `input_shape`.
//   - If any repeat is 0 the repeated tensor was empty, so the input gradient
//     is all zeros.
//   - Leading dims that repeat prepended (repeats has more entries than
//     input_shape) are summed away first.
//   - Every remaining dim with repeat > 1 is split into (repeat, size), and
//     all the "repeat" axes are summed in a single reduction.
Tensor repeat_backward(Tensor grad, IntArrayRef repeats, IntArrayRef input_shape) {
  auto find_iter = std::find(repeats.cbegin(), repeats.cend(), 0);
  if (find_iter != repeats.cend()) {
    return at::zeros(input_shape, grad.options());
  }
  const auto input_dims = input_shape.size();
  // Number of leading dims added because repeats is longer than input_shape.
  int64_t num_unsqueezed = grad.dim() - input_dims;
  for (const auto i : c10::irange(num_unsqueezed)) {
    (void)i; // Suppress unused variable warning
    grad = grad.sum(0, false);
  }

  at::DimVector grad_size, sum_dims;
  for (const auto dim : c10::irange(input_dims)) {
    int64_t repeat = repeats[dim + num_unsqueezed];
    // Reshape gradient (repeat > 1)
    // Index:      [..., dim    , ...]    [..., dim   ,  dim+1        , ...]
    // Shape: From [..., dimsize, ...] to [..., repeat, dimsize/repeat, ...]
    // The gradient tensor at 'dim' is reshaped to 'repeat' times of input tensor.
    // Then, sum up gradients over repeated tensors along 'dim', and reduce shape
    // from 'repeat * dimsize/repeat' to 'dimsize/repeat' ('input_dimsize').
    // Example:
    //        Size(3, 2)                                             Size(6, 2)
    //                                                             [[v1_0, v1_1],
    //                                                              [v1_2, v1_3],
    //        [[v0, v1],                   repeat(2, 1)             [v1_4, v1_5],
    //         [v2, v3],                  ------------->            [v2_0, v2_1],
    //         [v4, v5]]                                            [v2_2, v2_3],
    //                                                              [v2_4, v2_5]]
    //
    //    input grad (3, 2)            reshape (2, 3, 2)         output grad (6, 2)
    //                                  [[[g1_0, g1_1],            [[g1_0, g1_1],
    //                                    [g1_2, g1_3],             [g1_2, g1_3],
    // [[g1_0+g2_0, g1_1+g2_1],           [g1_4, g1_5]],            [g1_4, g1_5],
    //  [g1_2+g2_2, g1_3+g2_3],                                     [g2_0, g2_1],
    //  [g1_4+g2_4, g1_5+g2_5]]          [[g2_0, g2_1],             [g2_2, g2_3],
    //                                    [g2_2, g2_3],             [g2_4, g2_5]]
    //                                    [g2_4, g2_5]]]
    // If gradient tensor is reshaped to [..., dimsize/repeat, repeat, ...] and then
    // sum over 'dim+1'. The gradient for input is not correctly aligned with input.
    // Example:
    //     input grad (3, 2)            reshape (3, 2, 2)        output grad (6, 2)
    //                                  [[[g1_0, g1_1],
    //                                    [g1_2, g1_3]],           [[g1_0, g1_1],
    //                                                              [g1_2, g1_3],
    // [[g1_0+g1_2, g1_1+g1_3],          [[g1_4, g1_5],             [g1_4, g1_5],
    //  [g1_4+g2_0, g1_5+g2_1],           [g2_0, g2_1]],            [g2_0, g2_1],
    //  [g2_2+g2_4, g2_3+g2_5]]                                     [g2_2, g2_3],
    //                                   [[g2_2, g2_3],             [g2_4, g2_5]]
    //                                    [g2_4, g2_5]]]
    if (repeat != 1) {
      grad_size.push_back(repeat);
      sum_dims.push_back(grad_size.size() - 1);
    }
    // Don't need to reshape gradient into (repeat, input_shape[dim]) (repeat == 1)
    grad_size.push_back(input_shape[dim]);
  }
  // One-time Reshape & Sum
  // Reshape gradient to grad_size:
  //   1. If repeat equals to 1, append input size at that dimension,
  //   2. If repeat is larger than 1, append both repeat and input size at that dimension.
  // Sum over all "repeat" dimensions from sum_dims:
  // Example:
  // Input Size         (2,    3,    4,    5)
  // repeat             [4,    1,    9,    3]
  // output/grad Size   (8,    3,    36,   15)
  // grad_size          [4, 2,    3, 9, 4, 3, 5]
  // sum_dims           [0,          3,    5]

  // When repeat 1 time over all original dimensions, the empty sum_dims will reduce
  // the whole grad tensor into a scalar rather than keeping original dimensions.
  if (!sum_dims.empty()) {
    grad = grad.reshape(grad_size);
    grad = grad.sum(sum_dims);
  }
  return grad;
}

// Backward of fused dropout. `p1m` is 1 - p, so kept elements are rescaled
// by 1 / p1m.
Tensor _fused_dropout_backward(Tensor grad, Tensor mask, double p1m) {
  const double scale = 1. / p1m;
  if (!grad.requires_grad()) {
    return at::_masked_scale(grad, mask, scale);
  }
  // Double backward is required: build the result from ordinary
  // differentiable ops instead of the fused kernel.
  return grad * (mask.type_as(grad) * scale);
}

// Backward of native dropout, written only with differentiable ops.
// `scale` is 1 / (1 - prob), the factor applied to kept elements.
Tensor infinitely_differentiable_native_dropout_backward(const Tensor& grad, const Tensor& mask, double scale) {
  const Tensor scaled_mask = mask.type_as(grad) * scale;
  return grad * scaled_mask;
}

// Double backward of native dropout with respect to grad_output.
// The backward is grad * mask * scale, so `ggI` (cast to grad's dtype) is
// masked and scaled the same way.
Tensor native_dropout_double_backward(const Tensor& ggI, const Tensor& grad, const Tensor& mask, double scale) {
  const Tensor scaled_mask = mask.type_as(grad) * scale;
  return ggI.type_as(grad) * scaled_mask;
}

// Backward for a reduction that picked a single `value` out of `input`.
// The incoming `grad` is split evenly among all positions of `input` that
// equal `value`, with NaN treated as equal to NaN.
Tensor evenly_distribute_backward(Tensor grad, const Tensor & input, const Tensor & value) {
  bool any_tensor_subclass_like = areAnyTensorSubclassLike({grad, input, value});
  // Tensor subclasses and CUDA inputs take a fully tensor-based path. The
  // other branch calls .item() on `value`; presumably that is avoided here
  // because it forces a device sync on CUDA and does not work for subclasses.
  if (any_tensor_subclass_like || input.is_cuda()) {
    const auto input_isnan = input.isnan();
    const auto value_isnan = value.isnan();
    // The in-place logical_and_ is skipped for tensor subclasses (see
    // NOTE: [How to write vmap-compatible backward formulas]).
    const auto& input_and_value_isnan = any_tensor_subclass_like ?
      input_isnan.logical_and(value_isnan) :
      input_isnan.logical_and_(value_isnan);
    // Positions equal to value, or NaN where value is NaN.
    const auto mask = (input == value).logical_or_(input_and_value_isnan);
    return mask * (grad / mask.sum());
  } else {
    // NaN != NaN, so a NaN value is matched with isnan() instead of ==.
    auto mask = value.isnan().item<bool>() ? input.isnan() : input == value;
    return grad.new_zeros(input.sizes(), input.options()).masked_fill_(mask, grad / mask.sum());
  }
}

// Forward-mode counterpart of evenly_distribute_backward.
// The tangent of a reduction that picked `value` out of `input` is the mean
// of `fw_grad` over every position of `input` equal to `value`.
// As in evenly_distribute_backward, NaN is treated as equal to NaN. Without
// this, a NaN `value` matches no position, `count` is 0, and the tangent
// becomes NaN (0/0) even though the backward returns a finite gradient.
Tensor evenly_read_jvp(const Tensor& fw_grad, const Tensor & input, const Tensor & value) {
  // Out-of-place ops only, so this also works for tensor subclasses.
  const auto both_nan = input.isnan().logical_and(value.isnan());
  const auto mask = (input == value).logical_or(both_nan);
  const auto count = mask.sum();
  const auto grad_output = fw_grad / count;
  return at::sum(mask * grad_output);
}

// Gradient of the variance taken over all elements of `self`:
//   grad * 2 * (self - mean(self)) / (numel - correction)
static Tensor var_backward(const Tensor & grad, const Tensor & self, int64_t correction) {
  const int64_t dof = self.numel() - correction;
  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions)
  const double scale = 2.0 / dof;
  return scale * grad * (self - self.mean());
}

// Gradient of the variance of `self` over `dim_opt` (or over all elements
// when no dims are given or `self` is 0-d). `correction` defaults to 1,
// i.e. divide by N - 1.
Tensor var_backward(Tensor grad, const Tensor& self, c10::optional<IntArrayRef> dim_opt,
    c10::optional<int64_t> correction_opt, bool keepdim) {
  const auto correction = correction_opt.value_or(1);
  // Full reduction: defer to the all-elements overload.
  if (!dim_opt.has_value() || self.dim() == 0) {
    return var_backward(grad, self, correction);
  }
  const IntArrayRef dim = *dim_opt;
  // Re-insert the reduced dims so `grad` broadcasts against `self`. A 1-d
  // self reduced without keepdim leaves a 0-d grad, which broadcasts as is.
  if (self.dim() > 1 && !keepdim) {
    grad = unsqueeze_multiple(grad, dim, self.sizes().size());
  }
  const int64_t dof = _safe_size(self.sizes(), dim) - correction;
  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions)
  return (2.0 / dof) * grad * (self - self.mean(dim, /*keepdim=*/true));
}

// Forward-mode derivative of var. `self_t` is the tangent, `self_p` the
// primal input and `result` the primal output, used for the output shape in
// the full-reduction case.
Tensor var_jvp(const Tensor& self_t, const Tensor& self_p, const Tensor& result, c10::optional<IntArrayRef> dim_opt,
    c10::optional<int64_t> correction_opt, bool keepdim) {
  const auto correction = correction_opt.value_or(1);
  const bool full_reduction = self_p.dim() == 0 || !dim_opt.has_value();
  if (full_reduction) {
    // Reuse the backward formula with the conjugated tangent as "grad",
    // reduce it, then conjugate back.
    return var_backward(self_t.conj(), self_p, correction).sum().expand_as(result).conj();
  }
  const IntArrayRef dim = *dim_opt;
  const int64_t dof = _safe_size(self_p.sizes(), dim) - correction;
  const auto centered = self_p - self_p.mean(dim, /*keepdim=*/true);
  return ((2.0 / dof) * self_t.conj() * centered).sum(dim, keepdim).conj();
}

// Gradient of std = sqrt(var): d std = d var / (2 * std).
// Where the result is 0, that division is undefined, so the gradient there
// is set to 0.
Tensor std_backward(
    const Tensor& result, const Tensor& grad, const Tensor& self,
    c10::optional<IntArrayRef> dim, c10::optional<int64_t> correction, bool keepdim) {
  Tensor grad_var = grad / (result * 2);
  grad_var.masked_fill_(result == 0, 0);
  return var_backward(grad_var, self, dim, correction, keepdim);
}

// Gradient of mean over `dim`: the sum gradient divided by the number of
// elements reduced over.
Tensor mean_backward(Tensor grad, const IntArrayRef sizes, IntArrayRef dim, bool keepdim) {
  const auto reduced_count = _safe_size(sizes, dim);
  return sum_backward(grad, sizes, dim, keepdim) / reduced_count;
}

// Gradient of a full mean: every input element receives grad / numel.
Tensor mean_backward(Tensor grad, const IntArrayRef sizes, int64_t numel) {
  Tensor broadcast_grad = grad.expand(sizes);
  return broadcast_grad / numel;
}

// Chooses the per-dimension or the full-reduction mean gradient, depending on
// whether `dim` was supplied.
static Tensor mean_backward(
    const Tensor& grad, const IntArrayRef sizes, int64_t numel,
    c10::optional<IntArrayRef> dim, bool keepdim) {
  return dim.has_value() ? mean_backward(grad, sizes, *dim, keepdim)
                         : mean_backward(grad, sizes, numel);
}

// Backward of the fused var_mean / std_mean ops.
// grads[0] is the gradient for the var (or std, when `is_std`) output and
// grads[1] the gradient for the mean output. Either may be undefined. The
// contributions of the defined ones are summed, and the result is undefined
// only if both are. `r1` is passed to std_backward as the std result.
Tensor var_std_mean_backward(
    const variable_list& grads, const Tensor& self, const Tensor& r1,
    const Tensor& r2, c10::optional<IntArrayRef> dim,
    c10::optional<int64_t> correction, bool keepdim, bool is_std) {
  const Tensor& grad_spread = grads[0];
  const Tensor& grad_mean = grads[1];

  Tensor result;
  if (grad_spread.defined()) {
    if (is_std) {
      result = std_backward(r1, grad_spread, self, dim, correction, keepdim);
    } else {
      result = var_backward(grad_spread, self, dim, correction, keepdim);
    }
  }
  if (!grad_mean.defined()) {
    return result;
  }
  Tensor from_mean = mean_backward(grad_mean, self.sizes(), self.numel(), dim, keepdim);
  return result.defined() ? result + from_mean : from_mean;
}

// Gradient of masked_scatter with respect to its source tensor, whose shape
// is `sizes`.
// masked_select yields a 1-d tensor of the `grad` entries where `mask` is set.
// That tensor is padded with zeros up to the element count of `sizes`, then
// viewed with that shape.
Tensor masked_scatter_backward(const Tensor & grad, const Tensor & mask, IntArrayRef sizes) {
  const int64_t numel = std::accumulate(
      sizes.begin(), sizes.end(), int64_t{1}, std::multiplies<int64_t>());
  Tensor selected = grad.masked_select(mask);
  const auto missing = numel - selected.numel();
  if (missing > 0) {
    Tensor padding = at::zeros({missing}, grad.options());
    selected = at::cat({selected, padding}, 0);
  }
  return selected.view(sizes);
}

Tensor cholesky_jvp(const Tensor& input_tangent, const Tensor& L, bool upper) {
  // Forward-mode derivative of the Cholesky factorization.
  // Source: Iain Murray, "Differentiation of the Cholesky decomposition",
  // https://arxiv.org/abs/1602.07527, equation 8:
  //   dL = L * Phi(L^{-1} dA L^{-H})
  // Phi keeps the lower triangle and halves the diagonal.
  // For the upper case, work with conjugate transposes and transpose the
  // result back.
  const Tensor dA = upper ? input_tangent.mH() : input_tangent;
  const Tensor L_lower = upper ? L.mH() : L;

  const Tensor identity = at::eye(L.size(-1), L.options());
  const Tensor L_inv = at::linalg_solve_triangular(L_lower, identity, /*upper=*/false);
  Tensor phi = at::matmul(at::matmul(L_inv, dA), L_inv.mH());
  phi.tril_().diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).mul_(0.5);
  const Tensor dL = L_lower.matmul(phi);
  return upper ? dL.mH() : dL;
}

Tensor cholesky_backward(Tensor grad, bool upper, Tensor L) {
  // cf. Iain Murray (2016); arXiv 1602.07527
  // The returned gradient is symmetric, not triangular.
  // Cholesky assumes a symmetric input. Symmetric matrices form a subspace of
  // R^{n x n}, so the derivative is not well defined for off-diagonal
  // elements. We take the gradient with respect to the functionally
  // independent elements (the lower triangle) and reflect it onto the upper
  // triangle. A symmetric gradient gives stable updates and keeps the matrix
  // symmetric under gradient-based optimization.
  if (upper) {
    L = L.mH();
    grad = grad.mH();
  }
  const Tensor identity = at::eye(L.size(-1), L.options());
  const Tensor L_inv = at::linalg_solve_triangular(L, identity, /*upper=*/false);
  // Phi(L^H grad): lower triangle with the diagonal halved.
  Tensor phi = at::matmul(L.mH(), grad);
  phi.tril_().diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).mul_(0.5);

  const Tensor grad_input = at::matmul(at::matmul(L_inv.mH(), phi), L_inv);
  // Symmetrize by averaging with the conjugate transpose.
  return grad_input.add(grad_input.mH()).mul_(0.5);
}

// Backward of cholesky_inverse with respect to the Cholesky factor `L`.
// `inverse` is the forward result; presumably it is (L L^H)^{-1} for a lower
// factor and (L^H L)^{-1} when `upper` (confirm against the forward).
// With S = inverse (grad + grad^H) inverse, the gradient is
//   -S L   (lower)   or   -L S   (upper).
// The symmetrizing term uses grad.mH() rather than grad.mT(). The two agree
// for real inputs, and the conjugate transpose is what the derivation needs
// for complex inputs.
// An undefined `grad` yields a zero gradient shaped like `L`.
Tensor cholesky_inverse_backward(Tensor grad, Tensor L, bool upper, Tensor inverse) {
  if (!grad.defined()) {
    return at::zeros({1}, L.options()).expand_as(L);
  }
  Tensor common_term = grad + grad.mH();
  common_term = at::matmul(inverse, at::matmul(common_term, inverse));
  return upper ? -at::matmul(L, common_term) : -at::matmul(common_term, L);
}

// The formula for forward AD is adapted from
//
// Golub, Gene H., and Victor Pereyra. "The Differentiation of Pseudo-Inverses and Nonlinear
// Least Squares Problems Whose Variables Separate."
// SIAM Journal on Numerical Analysis 10(2). (1973). 413-432. doi: 10.1137/0710036
//
// We present a short derivation below:
// Let Ap := pinv(A), then Ap is the unique matrix such that
//
// Ap A Ap = Ap [1]
// A Ap A = A   [2]
//
// By differentiating [1] we get:
//
// dAp = dAp A Ap + Ap dA Ap + Ap A dAp [3]
//
// In the rhs of [3] the products involving dAp could be expressed as products of
// Ap^i, A^j, dA^k with i, j, k in {1, H}, where X^H = X.mH().
// To prove that, note (A Ap)^H = A Ap and (Ap A)^H = Ap A, which could be shown by
// taking the product between the SVD decompositions of A and Ap.
// Consider the conjugate-transposed [2]:
// (A Ap A)^H = A^H (A Ap) = A^H. By differentiating it we get:
// dA^H A Ap + A^H dA Ap + A^H A dAp = dA^H. By multiplying from the left by Ap^H
// and using Ap^H A^H = (A Ap)^H = A Ap:
// Ap^H dA^H A Ap + A Ap dA Ap + A Ap A dAp = Ap^H dA^H. By multiplying from the left
// by Ap and by applying [1] and [2] repeatedly until impossible we get:
// Ap Ap^H dA^H A Ap + Ap dA Ap + Ap A dAp = Ap Ap^H dA^H. By rearranging the terms:
//
// Ap A dAp = -Ap dA Ap + Ap Ap^H dA^H (I - A Ap) [4],
// which is one of the summands in [3].
//
// Similarly, by differentiating the conjugate-transposed [2] written differently, i.e.
// (A Ap A)^H = Ap A A^H = A^H we will get an expression for dAp A Ap, which is
//
// dAp A Ap = -Ap dA Ap + (I - Ap A) dA^H Ap^H Ap [5].
//
// By plugging in [4] and [5] into [3] we get the forward AD formula for pinv:
//
// dAp = -Ap dA Ap + (I - Ap A) dA^H Ap^H Ap + Ap Ap^H dA^H (I - A Ap).
Tensor pinv_jvp(
  const Tensor& A,
  const Tensor& pinvA,
  const Tensor& dA
) {
  // Evaluates, with Ap = pinvA,
  //   dAp = -Ap dA Ap + (I - Ap A) dA^H Ap^H Ap + Ap Ap^H dA^H (I - A Ap)
  // regrouping the products so that the intermediate K is as small as
  // possible for the shape (m x n) of A.
  // NoTF32Guard disables TF32 for the matmuls in this scope.
  at::NoTF32Guard disable_tf32;
  auto m = A.size(-2);
  auto n = A.size(-1);
  auto dAh = dA.mH();
  auto pinvAh = pinvA.mH();
  // optimization to produce matrices of the smallest dimension
  if (m <= n) {
    // K = Ap^H dA^H is (m x m).
    auto K = pinvAh.matmul(dAh);
    return pinvA.matmul(K - K.mH() - K.matmul(A.matmul(pinvA)))
         + (dAh - pinvA.matmul(A.matmul(dAh))).matmul(pinvAh.matmul(pinvA));
  }
  else {
    // K = Ap dA is (n x n).
    auto K = pinvA.matmul(dA);
    auto Kh = K.mH();
    return (Kh - K - pinvA.matmul(A).matmul(Kh)).matmul(pinvA)
         + (pinvA.matmul(pinvAh)).matmul(dAh - (dAh.matmul(A)).matmul(pinvA));
  }
}

// Reverse-mode gradient of pinv with respect to A, given `grad` (the gradient
// of pinvA). As in pinv_jvp, the products are grouped by the shape (m x n)
// of A so that the intermediate K is the smaller square matrix.
Tensor pinv_backward(
  const Tensor& grad,
  const Tensor& pinvA,
  const Tensor& A
) {
  // NoTF32Guard disables TF32 for the matmuls in this scope.
  at::NoTF32Guard disable_tf32;
  auto m = A.size(-2);
  auto n = A.size(-1);
  auto pinvAh = pinvA.mH();
  auto gradh = grad.mH();
  // optimization to produce matrices of the smallest dimension
  if (m <= n) {
    // K = grad^H pinvA is (m x m).
    auto K = gradh.matmul(pinvA);
    auto KpinvAh = K.matmul(pinvAh);
    return - (pinvA.matmul(K)).mH()
           + KpinvAh - (A.matmul(pinvA)).matmul(KpinvAh)
           + (pinvAh.matmul(pinvA)).matmul(gradh - K.matmul(A));
  }
  else {
    // K = pinvA grad^H is (n x n).
    auto K = pinvA.matmul(gradh);
    auto pinvAhK = pinvAh.matmul(K);
    return - (K.matmul(pinvA)).mH()
           + (gradh - A.matmul(K)).matmul(pinvA).matmul(pinvAh)
           + pinvAhK - pinvAhK.matmul(pinvA).matmul(A);
  }
}

// Backward of split_with_sizes: concatenates the per-chunk gradients back
// along `dim`. Chunk j has extent split_sizes[j] along `dim`.
Tensor split_with_sizes_backward(const std::vector<torch::autograd::Variable> &grads,
                                 IntArrayRef split_sizes, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options) {
  dim = at::maybe_wrap_dim(dim, sizes.size());

  // An undefined chunk gradient stands for all zeros. at::cat cannot take
  // undefined tensors, so zeros of the chunk's shape are materialized for it.
  std::vector<Tensor> chunks;
  chunks.reserve(grads.size());
  for (const auto j : c10::irange(grads.size())) {
    const auto& chunk_grad = grads[j];
    if (chunk_grad.defined()) {
      chunks.push_back(chunk_grad);
      continue;
    }
    auto chunk_shape = sizes.vec();
    chunk_shape[dim] = split_sizes[j];
    chunks.push_back(at::zeros(chunk_shape, options));
  }
  return at::cat(chunks, dim);
}

// Backward of split with a uniform `split_size`. This builds the explicit
// chunk sizes and defers to split_with_sizes_backward.
Tensor split_backward(const std::vector<torch::autograd::Variable> &grads,
                      int64_t split_size, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options) {
  dim = at::maybe_wrap_dim(dim, sizes.size());
  const int64_t dim_size = sizes[dim];
  const int64_t num_splits = grads.size();
  // All chunks have `split_size` elements along `dim` except the last, which
  // holds whatever remains.
  std::vector<int64_t> split_sizes(num_splits, split_size);
  split_sizes[num_splits - 1] = dim_size - split_size * (num_splits - 1);
  return split_with_sizes_backward(grads, split_sizes, dim, sizes, options);
}

// Double backward of max pooling with respect to grad_output.
// The last `dim` dims of `indices` are flattened into one axis. `grad`,
// flattened the same way after being made contiguous in indices' suggested
// memory format, is then gathered at `indices` and reshaped to
// indices.sizes().
Tensor max_pool_double_backward(const Tensor & grad, const Tensor & indices, int dim) {
  AT_ASSERT(indices.dim() >= dim);
  const auto kept_dims = indices.dim() - dim;
  auto flat_shape = indices.sizes().slice(0, kept_dims).vec();
  flat_shape.push_back(-1);
  const auto flat_indices = indices.view(flat_shape);
  const auto memory_format = indices.suggest_memory_format();
  const auto flat_grad = grad.contiguous(memory_format).view(flat_shape);
  return flat_grad.gather(-1, flat_indices).view(indices.sizes());
}

// Double backward of GLU with respect to its input.
// `input` is split along `dim` into halves a (first) and b (second); this
// assumes glu(input) = a * sigmoid(b).
// `grad` has input's shape: it is the incoming gradient for glu's input
// gradient, split the same way into ggI_a and ggI_b. `grad_output` is gO.
// With s = sigmoid(b):
//   d/da = ggI_b * gO * s(1-s)
//   d/db = ggI_b * a * gO * d[s(1-s)]/db + ggI_a * gO * s(1-s)
Tensor glu_double_backward(const Tensor & grad, const Tensor & grad_output, const Tensor & input, int64_t dim) {
  auto& gO = grad_output;
  auto input_size = input.size(dim) / 2;
  auto first_half = input.narrow(dim, 0, input_size);
  auto second_half = input.narrow(dim, input_size, input_size);
  auto sig_second_half = second_half.sigmoid();
  auto one_sub_sig_second_half = 1 - sig_second_half;
  // s * (1 - s), the derivative of sigmoid at b.
  auto sig_one_sub_sig = sig_second_half * one_sub_sig_second_half;

  auto ggI_first_half = grad.narrow(dim, 0, input_size);
  auto ggI_second_half = grad.narrow(dim, input_size, input_size);
  auto ggI_second_half_times_first_half = ggI_second_half * first_half;

  auto gI_first_half = ggI_second_half * gO * sig_one_sub_sig;
  // d[s(1-s)]/db = s(1-s)(1-s) - s * s(1-s)
  auto second_order_sh = sig_one_sub_sig * one_sub_sig_second_half - sig_second_half * sig_one_sub_sig;
  auto gI_second_half = ggI_second_half_times_first_half * gO * second_order_sh + ggI_first_half * gO * sig_one_sub_sig;
  return at::cat({gI_first_half, gI_second_half}, dim);
}

// Double backward of GLU with respect to grad_output.
// glu_backward is evaluated with an all-ones grad_output of the halved shape
// and multiplied by `grad`. The two halves along `dim` are then summed, so
// the result has grad_output's (halved) shape. This assumes glu_backward is
// linear in its grad_output argument.
Tensor glu_double_backward_grad_output(const Tensor & grad, const Tensor & input, int64_t dim) {
  if (dim < 0) {
    dim += input.dim();
  }
  auto half_sizes = input.sizes().vec();
  half_sizes[dim] /= 2;
  const int64_t half = half_sizes[dim];
  const Tensor unit_grad_output = at::ones(half_sizes, input.options());
  const Tensor weighted = grad * glu_backward(unit_grad_output, input, dim);
  return weighted.narrow(dim, 0, half) + weighted.narrow(dim, half, half);
}

// Backward of silu(x) = x * sigmoid(x), written only with differentiable ops:
//   silu'(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
Tensor infinitely_differentiable_silu_backward(
    const Tensor& grad_output,
    const Tensor& input) {
  const Tensor sig = input.sigmoid();
  const Tensor one_minus_sig = 1.0 - sig;
  const Tensor factor = 1.0 + input * one_minus_sig;
  return grad_output * sig * factor;
}

// Backward of mish(x) = x * tanh(softplus(x)), written only with
// differentiable ops. With sp = softplus(x) = log1p(exp(x)):
//   mish'(x) = tanh(sp) + x * sigmoid(x) * (1 - tanh(sp)^2)
Tensor infinitely_differentiable_mish_backward(
    const Tensor& grad_output,
    const Tensor& input) {
  const Tensor sig = input.sigmoid();
  const Tensor sp = input.exp().log1p();
  const Tensor tanh_sp = sp.tanh();
  const Tensor sech_sq = 1.0 - tanh_sp * tanh_sp;
  return grad_output * (tanh_sp + input * sig * sech_sq);
}

// Backward of logit, written only with differentiable ops:
//   logit'(x) = 1 / (x * (1 - x))
// With `eps`, the gradient is 0 outside [eps, 1 - eps] (presumably the
// clamped region of the forward). Without `eps`, inputs outside [0, 1] get NaN.
Tensor infinitely_differentiable_logit_backward(
    const Tensor& grad,
    const Tensor& self,
    c10::optional<double> eps) {
  const Tensor derivative = grad / (self * (1.0 - self));
  if (!eps) {
    const Tensor nan_fill = at::empty({}, self.options())
        .fill_(std::numeric_limits<double>::quiet_NaN());
    return at::where(
        at::logical_and(self >= 0.0, self <= 1.0), derivative, nan_fill);
  }
  const double lo = eps.value();
  const double hi = 1.0 - lo;
  return at::where(
      at::logical_and(self >= lo, self <= hi),
      derivative,
      at::zeros({}, self.options()));
}

// Double backward of kl_div with respect to grad_output.
// The unreduced elementwise kl_div backward is evaluated with `grad`, then
// reduced the same way the forward was (mean, sum, or none).
Tensor kl_div_double_backward_grad_output(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction, bool log_target) {
  Tensor unreduced = kl_div_backward(grad, input, target, at::Reduction::None, log_target);
  switch (reduction) {
    case at::Reduction::Mean:
      return unreduced.mean();
    case at::Reduction::Sum:
      return unreduced.sum();
    default:
      return unreduced;
  }
}

// Gradient of kl_div with respect to `target`, scaled by `grad_output`:
//   log_target:  exp(target) * (target + 1 - self)
//   otherwise:   log(target) + 1 - self, set to 0 where target == 0
// With mean reduction, the result is divided by target.numel().
Tensor kl_div_target_backward(Tensor grad_output, Tensor self, Tensor target, int64_t reduction, bool log_target) {
  Tensor grad_target;
  if (log_target) {
    grad_target = grad_output.mul(target.add(1).sub_(self).mul_(target.exp()));
  } else {
    grad_target = grad_output.mul(target.log().add_(1).sub_(self)).masked_fill_(target == 0, 0.);
  }

  if (reduction == at::Reduction::Mean) {
    grad_target.div_(target.numel());
  }
  return grad_target;
}

// Gradient of binary cross entropy with respect to `target`:
//   (log(1 - self) - log(self)) * grad [* weight]
// With mean reduction, the result is divided by target.numel().
// In-place ops are used only when the other operand is not a tensor subclass
// (see NOTE: [How to write vmap-compatible backward formulas]).
Tensor binary_cross_entropy_target_backward(
  const Tensor& grad,
  const Tensor& self,
  const Tensor& target,
  const c10::optional<Tensor>& weight,
  int64_t reduction) {
  auto grad_target = (1. - self).log_().sub_(self.log());
  if (areAnyTensorSubclassLike({grad})) {
    grad_target = grad_target * grad;
  } else {
    grad_target.mul_(grad);
  }

  if (isDefined(weight)) {
    const Tensor& w = weight.value();
    if (isTensorSubclassLike(w)) {
      grad_target = grad_target * w;
    } else {
      grad_target.mul_(w);
    }
  }

  if (reduction == at::Reduction::Mean) {
    grad_target.div_(target.numel());
  }
  return grad_target;
}


// Gradient of binary_cross_entropy_with_logits with respect to `target`:
//   with pos_weight:    (log(1 - sigmoid(self)) - pos_weight * log(sigmoid(self))) * grad_output
//   without pos_weight: self * (-grad_output)
// The result is optionally multiplied by `weight`, and divided by
// target.numel() for mean reduction. In-place variants are used only when
// none of the involved operands is a tensor subclass.
Tensor binary_cross_entropy_with_logits_target_backward(const Tensor& grad_output, const Tensor& self, const Tensor& target, const c10::optional<Tensor>& weight, const c10::optional<Tensor>& pos_weight, int64_t reduction) {
  Tensor grad_target;
  if (!isDefined(pos_weight)) {
    grad_target = self.mul(-grad_output);
  } else if (areAnyTensorSubclassLike({*pos_weight, grad_output})) {
    grad_target = (1. - self.sigmoid()).log_().sub(pos_weight->mul(self.sigmoid().log_())).mul(grad_output);
  } else {
    grad_target = (1. - self.sigmoid()).log_().sub_(pos_weight->mul(self.sigmoid().log_())).mul_(grad_output);
  }

  if (isDefined(weight)) {
    if (isTensorSubclassLike(*weight)) {
      grad_target = grad_target.mul(*weight);
    } else {
      grad_target.mul_(*weight);
    }
  }

  if (reduction == at::Reduction::Mean) {
    grad_target.div_(target.numel());
  }
  return grad_target;
}

// Forward-mode derivative of binary_cross_entropy_with_logits.
// It sums the contribution of the input tangent (`input_t`) and the target
// tangent (`target_t`), evaluated at the primals (`input_p`, `target_p`),
// optionally weighted, and then applies the loss reduction.
Tensor binary_cross_entropy_with_logits_jvp(const Tensor& input_t, const Tensor& target_t, const Tensor& input_p, const Tensor& target_p, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& pos_weight_opt, int64_t reduction) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  const Tensor& pos_weight = c10::value_or_else(pos_weight_opt, [] {return Tensor();});

  Tensor grad_input;
  Tensor grad_target;
  if (pos_weight.defined()) {
    // pos_weight may broadcast, so pos_weight * target is computed out of place.
    auto pw_target = pos_weight.mul(target_p);
    // d/d input = (pw_target + 1 - target) * sigmoid(input) - pw_target
    grad_input = input_t.mul(pw_target.add(1).sub_(target_p).mul_(input_p.sigmoid()).sub_(pw_target));
    // d/d target = log(1 - sigmoid(input)) - pos_weight * log(sigmoid(input))
    grad_target = target_t.mul((1. - input_p.sigmoid()).log_().sub_(pos_weight.mul(input_p.sigmoid().log_())));
  } else {
    grad_input = input_t.mul(input_p.sigmoid() - target_p);
    grad_target = -target_t.mul(input_p);
  }

  if (weight.defined()) {
    grad_input.mul_(weight);
    grad_target.mul_(weight);
  }
  return apply_loss_reduction(grad_target + grad_input, reduction);
}

// Second derivative of log(sigmoid(x)) is -sigmoid(x) * (1 - sigmoid(x)),
// written here as (sigmoid(x) - 1) * sigmoid(x) and scaled by `grad`.
Tensor log_sigmoid_double_backward(const Tensor & grad, const Tensor & input) {
  const Tensor sig = input.sigmoid();
  return grad * (sig - 1) * sig;
}

// Double backward of softmax with respect to its input. Sums are over `dim`
// with keepdim:
//   grad_output * grad
//   - sum(output * grad_output) * grad
//   - grad_output * sum(output * grad)
Tensor softmax_double_backward(const Tensor& grad, const Tensor& grad_output, int dim, const Tensor& output) {
  const Tensor out_dot_grad_output = (output * grad_output).sum(dim, true);
  const Tensor out_dot_grad = (output * grad).sum(dim, true);
  return grad_output * grad - out_dot_grad_output * grad - grad_output * out_dot_grad;
}

// NOTE: [How to write vmap-compatible backward formulas]
//
// See NOTE: [vmap-incompatible in-place operations] for what it means for an
// in-place operation to be incompatible with vmap.
//
// If an in-place operation used in a backward formula is vmap-incompatible,
// then as developers we have the following options:
//
// - If the in-place operation directly followed the creation of a tensor with
//   a factory function like at::zeros(...), we should replace the factory with a
//   corresponding grad.new_zeros(...) call. The grad.new_zeros(...) call
//   propagates the batch dims to the resulting tensor.
//   For example:
//     Before: at::zeros(input.sizes(), grad.options()).copy_(grad)
//     After:  grad.new_zeros(input.sizes()).copy_(grad)
//
// - If the in-place operation followed some sequence of operations, and we
//   want to be able to vmap over the backward formula as-is (this is
//   usually the case for simple (<15loc) backward formulas), then use
//   areAnyTensorSubclassLike to guard the operation. For example:
//             c = a * b
//     Before: c.mul_(grad)
//     After:  c = !areAnyTensorSubclassLike({c, grad}) ? c.mul_(grad) : c * grad
//
// - If we don't want to vmap directly over the backward formula (e.g., if the
//   backward formula is too complicated or has a lot of vmap-incompatible
//   operations), then register the backward formula as an operator and eventually
//   write a batching rule for it.

// Double backward of binary cross entropy with respect to `input`.
// The second derivative of the loss in x = input, t = target is
//   (x^2 - 2 x t + t) / ((x + eps)^2 (1 - x + eps)^2)
// where eps guards against division by zero. It is multiplied by
// grad * grad_output (and by weight, when given), and divided by
// input.numel() for mean reduction.
Tensor binary_cross_entropy_double_backward(const Tensor & grad_output, const Tensor & grad, const Tensor & input, const Tensor & target, const c10::optional<Tensor>& weight, int64_t reduction) {
  const auto eps = 1e-12;
  const auto numerator = input * input - 2 * input * target + target;
  const auto denominator = (input + eps).pow(2) * (1 - input + eps).pow(2);
  auto gI = numerator / denominator;
  const auto upstream = grad * grad_output;
  if (areAnyTensorSubclassLike({gI, grad})) {
    gI = gI * upstream;
  } else {
    gI *= upstream;
  }

  if (isDefined(weight)) {
    if (isTensorSubclassLike(*weight)) {
      gI = gI.mul(*weight);
    } else {
      gI *= *weight;
    }
  }
  return reduction == at::Reduction::Mean ? gI / input.numel() : gI;
}

// Double backward of binary cross entropy with respect to grad_output.
// The elementwise coefficient of grad_output in the BCE input gradient,
//   (x - t) / ((x + eps) (1 - x + eps)),
// is multiplied by `grad` (and by weight), then reduced to grad_output's shape.
// For mean reduction this is ggO.mean(), i.e. sum(ggO) / numel, which matches
// the Sum branch and the other *_double_backward_grad_output helpers.
// Previously the Mean branch returned the unreduced ggO / numel, shaped like
// input rather than like the single-value grad_output.
Tensor binary_cross_entropy_double_backward_grad_output(const Tensor & grad, const Tensor & input, const Tensor & target, const c10::optional<Tensor>& weight, int64_t reduction) {
  auto eps = 1e-12;
  // gradient wrt grad_output
  auto ggO = (input - target) / ((input + eps) * (1 - input + eps));
  if (!areAnyTensorSubclassLike({ggO, grad})) {
    ggO *= grad;
  } else {
    ggO = ggO * grad;
  }

  if (isDefined(weight)) {
    if (!isTensorSubclassLike(*weight)) {
      ggO *= *weight;
    } else {
      ggO = ggO.mul(*weight);
    }
  }
  if (reduction == at::Reduction::Mean) {
    return ggO.mean();
  } else if (reduction == at::Reduction::Sum) {
    return ggO.sum();
  }
  return ggO;
}

// Double backward of l1_loss with respect to `self`.
// For real inputs the first derivative of |self - other| is piecewise
// constant, so the result is zero. For complex inputs it goes through
// sgn_backward, scaled by grad_output and divided by numel for mean reduction.
Tensor l1_loss_double_backward(const Tensor & grad, const Tensor & grad_output, const Tensor & self, const Tensor & other, int64_t reduction) {
  if (!self.is_complex()) {
    return at::zeros_like(grad);
  }
  const auto diff = self - other;
  auto output = grad_output * sgn_backward(diff.sgn(), grad, diff);
  if (reduction == at::Reduction::Mean) {
    output /= self.numel();
  }
  return output;
}

// Double backward of l1_loss with respect to grad_output.
// The unreduced elementwise l1_loss backward is evaluated with conj(grad),
// then reduced the same way the forward was.
// handle_r_to_c is applied on every path. Previously it was skipped for
// mean/sum, so a real grad_output could receive a complex gradient when the
// input was complex; presumably handle_r_to_c keeps the real part in that case.
Tensor l1_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction) {
  auto output = at::l1_loss_backward(grad.conj(), input, target, at::Reduction::None);
  if (reduction == at::Reduction::Mean) {
    output = output.mean();
  } else if (reduction == at::Reduction::Sum) {
    output = output.sum();
  }
  return handle_r_to_c(grad_output, output);
}

// Double backward of smooth_l1_loss with respect to `input`.
// The second derivative is 1 / beta where |input - target| < beta and 0
// elsewhere. It is scaled by `grad` and divided by numel for mean reduction.
Tensor smooth_l1_loss_double_backward(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction, double beta) {
  // special case to protect against a divide-by-zero.
  if (beta == 0) {
    return at::zeros(grad.sizes(), grad.options());
  }
  const auto inside = (input - target).abs() < beta;
  auto grad_input = grad * inside.type_as(grad) / beta;
  if (reduction == at::Reduction::Mean) {
    grad_input /= input.numel();
  }
  return grad_input;
}

// Double backward of smooth_l1_loss with respect to grad_output.
// The backward is linear in grad_output:
//   - unreduced: evaluate it with `grad` in place of grad_output;
//   - reduced: evaluate it with a ones-like grad_output and contract the
//     result against `grad`.
Tensor smooth_l1_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction, double beta) {
  if (reduction != at::Reduction::None) {
    const Tensor unit_backward = smooth_l1_loss_backward(ones_like(grad_output), input, target, reduction, beta);
    return (unit_backward * grad).sum();
  }
  return smooth_l1_loss_backward(grad, input, target, reduction, beta);
}

// Double backward of huber_loss with respect to `input`.
// The second derivative is 1 where |input - target| < delta and 0 elsewhere.
// It is scaled by `grad` and divided by numel for mean reduction.
Tensor huber_loss_double_backward(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction, double delta) {
  const auto within_delta = (input - target).abs() < delta;
  auto grad_input = grad * within_delta;
  if (reduction == at::Reduction::Mean) {
    grad_input /= input.numel();
  }
  return grad_input;
}

// Double backward of huber_loss with respect to grad_output.
// The backward is linear in grad_output:
//   - unreduced: evaluate it with `grad` directly;
//   - reduced: evaluate it with a ones-like grad_output and contract the
//     result against `grad`.
Tensor huber_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction, double delta) {
  if (reduction != at::Reduction::None) {
    const Tensor unit_backward = huber_loss_backward(ones_like(grad_output), input, target, reduction, delta);
    return (unit_backward * grad).sum();
  }
  return huber_loss_backward(grad, input, target, reduction, delta);
}

// Double backward of mse_loss with respect to `input`.
// (input - target)^2 has the constant second derivative 2. The result is
// divided by numel for mean reduction.
Tensor mse_loss_double_backward(const Tensor & grad, const Tensor & input, int64_t reduction) {
  Tensor twice_grad = 2 * grad;
  if (reduction != at::Reduction::Mean) {
    return twice_grad;
  }
  twice_grad /= input.numel();
  return twice_grad;
}

// Double backward of mse_loss with respect to grad_output.
// The backward is linear in grad_output:
//   - unreduced: evaluate it with `grad` directly;
//   - reduced: evaluate it with a ones-like grad_output and contract the
//     result against `grad`.
Tensor mse_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction) {
  if (reduction != at::Reduction::None) {
    const Tensor unit_backward = mse_loss_backward(ones_like(grad_output), input, target, reduction);
    return (unit_backward * grad).sum();
  }
  return mse_loss_backward(grad, input, target, reduction);
}

// Double backward of soft_margin_loss with respect to `input`.
// With z = exp(-target * input), the second derivative of
// log(1 + z) is target^2 * z / (1 + z)^2. It is scaled by `grad` and divided
// by numel for mean reduction.
Tensor soft_margin_loss_double_backward(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction) {
  const auto z = (input * -target).exp();
  const auto one_plus_z = z + 1;
  auto grad_input = grad * (target * target) * z / (one_plus_z * one_plus_z);
  if (reduction == at::Reduction::Mean) {
    grad_input /= input.numel();
  }
  return grad_input;
}

// Double backward of soft_margin_loss with respect to grad_output.
// The backward is linear in grad_output:
//   - unreduced: evaluate it with `grad` directly;
//   - reduced: evaluate it with a ones-like grad_output and contract the
//     result against `grad`.
Tensor soft_margin_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction) {
  if (reduction != at::Reduction::None) {
    const Tensor unit_backward = soft_margin_loss_backward(ones_like(grad_output), input, target, reduction);
    return (unit_backward * grad).sum();
  }
  return soft_margin_loss_backward(grad, input, target, reduction);
}

// Double backward of softplus with respect to `input`.
// With x = beta * input, the result is
//   sigmoid_backward(grad, sigmoid(x)) * beta
// masked to zero wherever x >= threshold.
Tensor softplus_double_backward(const Tensor & grad, const Tensor & input, const Scalar& beta, const Scalar& threshold) {
  const auto scaled = input * beta;
  const auto below_threshold = (scaled < threshold).type_as(grad);
  return sigmoid_backward(grad, scaled.sigmoid()) * below_threshold * beta;
}


// NOTE [ as_strided Backward and layout-aware/agnostic autograd ]
//
// `storage_offset` is ignored for simplicity in this note. If you just want the
// full algorithm without explanation, scroll down to bottom of this note.
//
// Implementing the backward of as_strided is tricky because you have to deal
// with mappings that map one memory location to multiple indices, i.e., the
// output tensor has multiple indices pointing to **overlapping** memory
// addresses. This can happen in all sorts of weird cases. For example,
//
//   x = torch.randn(15)
//   x.as_strided([3, 3], [1, 0])  # "expand" case
//   x.as_strided([3, 3], [2, 1])  # "size too large" case
//   x.as_strided([3, 2], [3, 6])  # res[2, 0] points to 2*3 + 0*6 = 6
//                                 # res[0, 1] points to 0*3 + 1*6 = 6
//
// Here is the general strategy we apply in implementing as_strided backward:
//   0. ??? (optimization step. we will talk about this later)
//   1. Create some underlying flattened tensor as if it is the base tensor
//      representing the contiguous memory storage for both input and output.
//   2. Use the output geometry to scatter (or index_add) the gradients into
//      this storage tensor.
//   3. ??? (fix for input tensor with overlapping memory. we will talk about
//           this later)
//   4. Return the as_strided view of the storage tensor using input geometry.
//
// In step (2), if the output tensor doesn't have overlapping memory, we can
// safely scatter (`storage.as_strided(output_geometry).copy_(grad)`);
// otherwise, we must use `index_add` as gradients at different indices may need
// to be summed to a single location.
//
// For example, in this case:
//
//   x = torch.randn(3)
//   y = x.as_strided([3, 3], [1, 0])  # "expand" case
//                                     # size   [ 3, 3]
//                                     # stride [ 1, 0]
//   y.backward()  # step (1): contiguous storage tensor `s` of size 3, which
//                             is large enough to be used as underlying storage
//                             for `x` and `y`.
//                               s = [ 0, 0, 0]
//                 # step (2): since `y` has overlapping memory, index_add grad
//                             into `s` based on `y`'s geometry, i.e.,
//                             s[i * y.stride(0) + j * y.stride(1)] += gy[i, j].
//                               s = [ 3, 3, 3]
//                 # step (4): as_strided view `s` using `x`'s geometry
//                               s = [ 3, 3, 3]
//                               grad_input = s.as_strided(x.size(), x.stride())
//                                          = s.as_strided([3], [1])
//                                          = [ 3, 3, 3]
//
// This is exactly what we would get if using `expand`. However, here the input
// tensor doesn't have overlapping memory. If it does, we must add an extra step
// before (4). Considering this case:
//
//   t = torch.randn(3)
//   x = t.expand(3, 3)            # input with overlapping memory
//                                 # size   [3, 3]
//                                 # stride [0, 1]
//   y = x.as_strided([1], [1])    # contiguous output
//                                 # size   [1]
//                                 # stride [1]
//   y.backward()  # step (1): contiguous storage tensor `s` of size 3, which
//                             is large enough to be used as underlying storage
//                             for `x` and `y`.
//                               s = [ 0, 0, 0]
//                 # step (2): scatter grad into `s` based on `y`'s geometry
//                               s = [ 1, 0, 0]
//                 # step (4): as_strided view `s` using `x`'s geometry
//                               s = [ 1, 0, 0]
//                               grad_input = s.as_strided([3, 3], [0, 1])
//                                          = s.as_strided([3, 3], [0, 1])
//                                          = [[ 1, 0, 0],
//                                             [ 1, 0, 0],
//                                             [ 1, 0, 0]]
// Is this result correct?
//
// `x.as_strided([1], [1])` call is obviously equivalent to
// `x[(0,) * x.dim()].view(1)` for any `x`. But autograd through the second
// gives gradient `[ [ 1, 0, 0], [ 0, 0, 0], [ 0, 0, 0]]`. For this specific
// case, indexing `x` at any index in first column is also equivalent, and
// yields a gradient of shape `[3 x 3]` containing eight 0's and one 1. There is
// an `x.size(1)`-times difference between these gradients computed from other
// PyTorch ops and the gradient we got from as_strided.
//
// You might conclude that the gradients from as_strided are wrong. However,
// let's first see why they are actually reasonable. Consider the pointwise
// perturbations by `delta` anywhere in the first column of `x`. It will lead to
// a `delta` change in the same memory location, and then `y` will change by
// `delta`. So one can say the gradient should be exactly 1 at the first column,
// as given by our above procedure.
//
// In the above computation of numerical gradients, they only match the
// analytical results because strides and memory locations are considered in the
// forward pass, i.e., this op (including both forward and backward) is
// layout-aware.
//
// However, in PyTorch, most (probably all) other ops (forward and backward) are
// layout-agnostic. E.g.,
//
//   t = torch.randn(1)
//   x = t.expand(2)
//   y = x.sum()
//   y.backward()
//
// Layout-agnostic autograd (as it is currently in PyTorch) will give you
//
//   gy = 1
//   gx = [ 1, 1]  # SumBackward:    torch.ones_like(x)
//   gt = [ 2]     # ExpandBackward: gx.sum()
//
// Note that `gx = [ 1, 1]`. However, if you perturb any value in `x` by `delta`
// (the other will also change by `delta`), `y` will change by `2 * delta`. So
// the gradients, if strides are taken into consideration, should be 2.
//
// Layout-aware autograd should give you
//
//   gy = 1
//   gx = [ 2, 2]  # Because the backward considers the fact that the input `x`
//                 # is already expanded.
//   gt = [ 2]     # Layout-aware backward of expand is just a slicing because
//                 # the previous backward should have already taken care of
//                 # strides and made sure that gradients are the same along the
//                 # expanded dimension.
//
// As shown above, these two types are not compatible. Therefore, we must either
// make as_strided layout-agnostic, or make all other ops layout-aware.
//
// It is difficult to support layout-aware autograd (at least in the current
// codebase structure), because it would mean
//   1. storing tensor geometries of every input tensor for backward
//   2. depending on input geometry, the gradient computed from backward changes
//   3. ideally enforcing gradient of T to always have same strides as T
// (although these two methods only differ when it comes to overlapping memory)
//
// Therefore, we must formulate `as_strided` in a layout-agnostic way, i.e.,
// giving the same output regardless of the input layout. We consider
// `input.stride()` as a separate independent fixed argument `input_stride`.
// Then, `as_strided(input, size, stride)` can be thought of as:
//   1. "Scatter" each value of `input` into a "storage" using storage location
//      computed from the value's index in `input`, `input.size()` and
//      `input_stride`, but if N values end up in the same location, the value
//      is the average of those N values (they will be the same value anyways).
//
//      Formal description:
//        Denote the set of all input indices that pointing to the same storage
//        location `storage[n]` as `S(n)`, i.e.,
//
//            S(n) = { index : <index, input_stride> == n, index is valid given input.size() },
//
//        where `<x, y>` is the dot product between `x` and `y`.
//
//        Then, the process is:
//
//            storage[n] = Avg { S(n) }
//
//        Note that all values in `S(n)` are the same (they point to the same
//        memory location anyways), so this step doesn't change anything, but it
//        effectively avoids having the dependency on the layout of `input`.
//        I.e., the result holds fixed regardless of the layout of `input`, as
//        long as `input_stride` is fixed.
//
//      NOTE: for forward pass, we can equivalently simply select any one of
//            `S(n)` as `storage[n]`. However, considering this as an average
//            operation makes backward easier (so all values in set
//            `{ grad_input[i] : i in S(n) }` are the same, and it can use the
//            same geometry as input).
//   2. As usual, return the as_strided view of `storage` using required output
//      `size` and `stride`.
//
// To backward through this layout-agnostic version, we simply add the following
// step:
//   .... (scatter gradients into the storage tensor using output geometry)
//   3. For all storage location n, `storage[n] /= |S(n)|`.
//   .... (return as_strided view of the storage tensor using input geometry)
//
// Finally, we note that these general operations are expensive, so we apply the
// following optimizations:
//   Add step (0): For all output dimension `d` with output stride 0, sum the
//                 gradients along dimension `d` (don't keepdim), and remove
//                 dimension `d` from output size and stride.
//                 (An optimization for "expand" cases so we may avoid step (3))
//  Only apply step (3) when input tensor has overlapping memory.
//
// FULL ALGORITHM:
//   0. For all output dimension `d` with output stride 0, sum the gradients
//       along dimension `d` (don't keepdim), and remove dimension `d` from
//       output size and stride.
//   1. Create some underlying flattened tensor as if it is the base tensor
//      representing the contiguous memory storage for both input and output.
//   2. Use the output geometry to scatter (or index_add) the gradients into
//      this storage tensor `storage`.
//   3. If input tensor has overlapping memory,
//      For all storage location `i`, `storage[i] /= N(i)`, where `N(i)` is the
//      number of indices in input geometry pointing to the same storage
//      location `i` (i.e., `|S(i)|` in equations above).
//   4. Return the as_strided view of the storage tensor using input geometry.
//
// See NOTE [ Detecting Memory Overlap Within A Strided Tensor ] on how to
// roughly detect overlapping memory.


// NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
//
// Checking memory overlap within a strided tensor is the special case of
// detecting memory overlap of two strided tensors, where the two tensors start
// at the same memory address. The latter is HARD (see #8212).
//
// But even this special case isn't simple. This note describes a check for an
// even more constrained simple case where we can be certain that there is no
// overlap.
//
// The checking algorithm can be described as:
//   0. Return [ pass check ] if any dimension has size 0
//   1. Ignore all dimensions that have size 1
//   2. If no remaining dimensions, return [ pass check ]
//   3. Sort the remaining dimensions according to the strides decreasingly
//   4. Check that for each dimension k,
//
//           stride[k] > \sum_{ i > k } (size[i] - 1) * stride[i]
//
//      That is equivalent to, after reordering the dimensions so strides are
//      in decreasing order, checking that stride of each dimension is larger
//      than the maximum memory offset in a slice at that dimension.
//
// Obviously this check passes for contiguous tensors ( the dimensions will be
// already sorted with LHS = stride[0] = \prod_{i > 0} size[i] being exactly 1
// larger than RHS ). Similarly, the check passes for tensors contiguous in all
// but the last dimension, with LHS = stride[0] = stride[-1] * \prod_{i > 0} size[i]
// being exactly stride[-1] larger than RHS. (*)
//
// We will show that these view operations, including all our view operations
// *except for* general as_strided and unfold, also preserve this invariant:
//
//  alias:      Obviously preserves
//
//  expand:     All changed dimensions are removed in step (1)
//
//  view:       Consider the input dimensions as grouped into consecutive
//              dimension "blocks", where dimensions are contiguous in each
//              one. view only works when the output dimensions can also be
//              grouped into the same consecutive blocks of same ordering.
//
//              NB: this means that the number of elements and stride of the
//                  last dimension in each block is the same in input and
//                  output. (**)
//
//              Notation:
//                Consider a single such block B,
//                    ... B_prev[-1]], [ B[0], ..., B[i], ..., B[k] = B[-1] ], [ B_next[0], ...
//                                start--^^^^                  ^^^^^^^^^^^^--end
//                Each B[i] denotes a dimension index such that B[i] = B[0] + i.
//
//              We first show that if a tensor (i.e., input) satisfies the
//              invariant, after sorting, the dimensions within each block
//              still remain consecutive. (***)
//
//                After removing dimensions of size 1, the dimensions within a
//                block are already sorted by strides in descending order. So
//                sorting all dimensions will not change the relative ordering
//                among them.
//
//                Assume that some block B is not consecutive after sorting,
//                i.e., there exists a dimension d between B[0] and B[-1] in
//                sorted order.
//
//                By (*), we know that
//                       stride[B[0]]
//                    =  \sum_{i > 0}   (size[B[i]] - 1) * stride[B[i]] + stride[B[-1]]
//                    <  \sum_{i > 0}   (size[B[i]] - 1) * stride[B[i]] + stride[d]
//                    <= \sum_{i > 0}   (size[B[i]] - 1) * stride[B[i]] + (size[d] - 1) * stride[d]
//                    <= \sum_{j > B[0]} (size[j]    - 1) * stride[j],
//
//                where the first <   comes from sorting and
//                      the second <= comes from the fact that dimension d
//                                               exists after step (1) and
//                                               thus must have size greater
//                                               than 1
//                      the third  <= comes from the fact that each term in
//                                               the sum is non-negative
//
//                Then we have a contradiction as the invariant must not be
//                satisfied at B[0]. So the original proposition is true.
//
//              Now that we established the above claim (***), we consider the
//              view operation as first sorting the dimensions (i.e., blocks),
//              apply the original view (since it only cares about dimensions
//              being consecutive and contiguous within each block), and then undo
//              the sort.
//
//              Consider a single block B in the output,
//                  ... ], [ B[0], ..., B[i], ..., B[k] = B[-1] ], [ ...
//                    start--^^^^                  ^^^^^^^^^^^^--end
//
//              By (*), we know that for all i
//                  stride[B[i]] = stride[B[-1]] +
//                                \sum_{j=i+1}^{k} (size[B[j]] - 1) * stride[B[j]]
//
//              Then the invariant is obviously satisfied at every dimension
//              in this block if it is satisfied at dimension B[-1]. It only
//              remains to show that it is satisfied at the last dimension in
//              each block.
//
//              Since the same blocks are present in both input and output
//              with the same ordering, we will abuse the notation in the
//              following statements.
//
//              By (*), we know that the following holds for both input and
//              output, for any block B:
//                    \sum_{i > B[-1]} (size[i] - 1) * stride[i]
//                  = \sum_{block B' after B} (\prod_{j in B'} size[j] - 1) * stride[B'[-1]]
//                  = \sum_{block B' after B} (numel(B') - 1) * stride[B'[-1]].
//                    ^^^^^^^^^^^^^^^^^^^^^^^|^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
//              By (**), we know that, this quantity in the above equation
//              remains the same in input and output. So both
//                  \sum_{i > B[-1]} (size[i] - 1) * stride[i]
//              and
//                  stride[B[-1]]
//              are the same in input and output.
//
//              These two quantities are exactly the RHS and LHS (respectively) of the
//              invariant inequality. Since by assumption the invariant is
//              satisfied in input at B[-1], it is also satisfied in output at
//              B[-1]. This concludes the proof.
//
//  squeeze:    Special case of view
//
//  unsqueeze:  Special case of view
//
//  slice:      Consider slicing dimension i with step = k >= 1.
//
//              Let stride' and size' be the output strides and sizes. We have
//
//                  stride'[i] = k * stride[i]
//                  size'[i] <= ceil(size[i] / k),
//
//              which is equivalent to (size'[i] - 1) * k <= size[i] - 1.
//
//              If size'[i] = 1, invariant is obviously satisfied as we are
//              just removing a dimension (after step (1)).
//
//              Assume size'[i] > 1, so that size[i] - 1 >= k.
//
//              By assumption, the invariant is satisfied at every dimension
//              in input.
//
//              For any dimension j, if stride[j] > stride[i], we have
//                  stride'[j] =  stride[j]
//                             >  (size[i] - 1) * stride[i]
//                             >= k * stride[i]
//                             =  stride'[i].
//
//              If stride[j] < stride[i], we have
//                  stride'[j] = stride[j] < stride[i] <= stride'[i].
//
//              So the sorting order remains unchanged after slice.
//
//              Since
//                     (size'[i] - 1) * stride'[i]
//                  =  (size'[i] - 1) * k * stride[i]
//                  <= (size[i] - 1) * stride[i],
//              the term from this dimension i in the invariant inequality at
//              other dimensions can only decrease after slice. So the
//              invariant is preserved.
//
//  narrow:     Special case of slice
//
//  select:     narrow + squeeze
//
//  permute:    Sorting makes permutation of dimensions irrelevant
//
//  transpose:  Sorting makes swapping dimensions irrelevant
//
//  diagonal:   Effectively merging two dimensions i and j into a new
//              dimension k s.t.
//                  stride'[k] =  stride[i] + stride[j]
//                  size'[k]   <= min(size[i], size[j]),
//              where stride and size are on the input, and stride' and size'
//              are on the output.
//
//              Assuming that size[i] > 1 and size[j] > 1. If any has size 1,
//              then this is unsqueeze on that dimension.
//
//              WLOG, say stride[i] <= stride[j].
//
//              Each dimension d in input with stride[d] > stride[j] has
//                  stride'[d] =  stride[d]
//                             >  (size[i] - 1) * stride[i] + (size[j] - 1) * stride[j]
//                             >= stride[i] + stride[j]
//                             =  stride'[k].
//              So, considering the sorted dimensions, this is effectively
//              removing i, and replacing j with k.
//
//              For dimensions d with stride[i] < stride[d] < stride[j], the
//              term from dimension i is removed in the invariant inequality.
//              For dimensions d with stride[d] > stride[j], we have
//                     (size'[k] - 1) * stride'[k]
//                  <= (min(size[i], size[j]) - 1) * (stride[i] + stride[j])
//                  <= (size[i] - 1) * stride[i] + (size[j] - 1) * stride[j],
//              so the term from i and j in the invariant can only decrease.
//
//              So this is generally relaxing the constraint, and thus it
//              preserves it.

static inline bool _maybe_overlapping_memory(IntArrayRef sizes, IntArrayRef strides) {
  if (sizes.empty()) {
    return false;
  }

  // Dimension indices ordered by ascending stride.
  std::vector<std::size_t> order(sizes.size());
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(),
      [&strides](std::size_t a, std::size_t b) { return strides[a] < strides[b]; });

  // Largest offset reachable using only the dimensions visited so far.
  int64_t reach = 0;
  for (const std::size_t d : order) {
    const int64_t step = strides[d];
    if (step <= reach) {
      return true;
    }
    reach += step * (sizes[d] - 1);
  }
  return false;
}

// Returns the minimum storage size (in elements) needed to contain a tensor of
// the given sizes, strides and storage_offset: one past the largest offset it
// can address, assuming non-negative strides. If any dimension has size 0 the
// tensor addresses nothing and storage_offset itself is returned.
// Helper for as_strided_backward
static inline int64_t _min_storage_size(IntArrayRef sizes, IntArrayRef strides, int64_t storage_offset) {
  const int64_t ndim = sizes.size();
  int64_t last_offset = storage_offset;
  for (const auto d : c10::irange(ndim)) {
    const auto extent = sizes[d];
    if (extent == 0) {
      return storage_offset;
    }
    last_offset += (extent - 1) * strides[d];
  }
  return last_offset + 1;
}

// See NOTE [ as_strided Backward and layout-aware/agnostic autograd ] for explanation
//
// Computes the gradient of `input.as_strided(sizes, strides, storage_offset_)`
// with respect to `input`, whose geometry is described by `input_geometry`.
//
//   grad            - incoming gradient; its dimension i corresponds to
//                     sizes[i] / strides[i] (the output geometry).
//   input_geometry  - sizes, strides and storage offset of the forward input.
//   sizes, strides  - output geometry passed to the forward as_strided.
//   storage_offset_ - forward storage offset; when absent, the input's own
//                     storage offset is used.
//
// Returns zeros of the input's sizes if the output or the input has a size-0
// dimension. Otherwise returns a view with the input's full sizes/strides into
// a freshly allocated (grad.new_zeros) buffer that holds the scattered grads.
Tensor as_strided_backward(Tensor grad, TensorGeometry input_geometry, IntArrayRef sizes, IntArrayRef strides, optional<int64_t> storage_offset_) {
  // For output geometry,
  //   check for size 0 dimensions,
  //   skip size 1 dimensions,
  //   reduce grad on expanded dims (stride=0, size>1)
  // Step (0)     for the algorithm in NOTE [ as_strided Backward and layout-aware/agnostic autograd ]
  // Step (0)~(1) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
  //              on output geometry
  auto storage_offset = storage_offset_.value_or(input_geometry.storage_offset());
  auto odim = grad.dim();
  std::vector<int64_t> out_sizes_, out_strides_;
  out_sizes_.reserve(odim);
  out_strides_.reserve(odim);
  // Iterate from the last dim backwards so that squeeze/sum on dim i never
  // shifts the indices of dims that are still to be visited; kept dims are
  // prepended to restore their original order.
  for (int64_t i = odim - 1; i >= 0; i--) {
    auto size_i = sizes[i];
    auto stride_i = strides[i];
    if (size_i == 0) {
      return at::zeros(input_geometry.sizes(), grad.options());
    } else if (size_i == 1) {
      grad = grad.squeeze(i);
    } else if (stride_i == 0) {
      grad = grad.sum(i, false);
    } else {
      out_sizes_.insert(out_sizes_.begin(), size_i);
      out_strides_.insert(out_strides_.begin(), stride_i);
    }
  }
  // Step (2)~(4) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
  //              on output geometry
  auto out_maybe_overlap = _maybe_overlapping_memory(out_sizes_, out_strides_);

  // For input geometry,
  //   check for size 0 dimensions,
  //   skip size 1 dimensions,
  // Step (0)~(1) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
  //              on input geometry
  auto idim = input_geometry.dim();
  // inp_sizes/inp_strides keep the full input geometry (used in step (4));
  // inp_sizes_/inp_strides_ drop size-1 dims (used for overlap and indexing).
  IntArrayRef inp_sizes = input_geometry.sizes(), inp_strides = input_geometry.strides();
  std::vector<int64_t> inp_sizes_, inp_strides_;
  inp_sizes_.reserve(idim);
  inp_strides_.reserve(idim);
  for (int64_t i = idim - 1; i >= 0; i--) {
    auto size_i = inp_sizes[i];
    auto stride_i = inp_strides[i];
    if (size_i == 0) {
      return at::zeros(input_geometry.sizes(), grad.options());
    } else if (size_i != 1) {
      inp_sizes_.insert(inp_sizes_.begin(), size_i);
      inp_strides_.insert(inp_strides_.begin(), stride_i);
    }
  }
  // Step (2)~(4) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
  //              on input geometry
  auto inp_maybe_overlap = _maybe_overlapping_memory(inp_sizes_, inp_strides_);


  // Rest of this function implements
  // Step (1)~(4) for the algorithm in NOTE [ as_strided Backward and layout-aware/agnostic autograd ]
  // TODO: Raise if not all output values are visible in input geometry.
  //       Technically speaking, if you treat those values as constants, not
  //       raising is fine, and mathematically correct. However, these values
  //       really are contained in some base tensor, and by treating them as
  //       constants we are ignoring this tight dependency. Therefore, it is
  //       more sensible to raise here.

  // Step (1): create underlying tensor as "storage"
  // Both geometries are rebased so the smaller of the two storage offsets maps
  // to index 0 of the new buffer; base_size is large enough for either one.
  auto shared_offset = std::min(input_geometry.storage_offset(), storage_offset);
  auto inp_effective_offset = input_geometry.storage_offset() - shared_offset;
  auto out_effective_offset = storage_offset - shared_offset;
  auto base_size = std::max(
    _min_storage_size(inp_sizes_, inp_strides_, inp_effective_offset),
    _min_storage_size(out_sizes_, out_strides_, out_effective_offset)
  );
  auto storage = grad.new_zeros({base_size});

  // prepare indices tensor if we will do index_add_ later
  // (flatten_full_indices[p] == p, so an as_strided view of it yields the
  // storage position addressed by every element of that geometry)
  c10::optional<at::Tensor> flatten_full_indices;
  if (inp_maybe_overlap || out_maybe_overlap) {
    flatten_full_indices = at::arange(0, base_size, grad.options().dtype(at::kLong));
  }

  // Step (2): use output geometry to scatter gradients into storage
  if (out_maybe_overlap) {
    // Several output elements may share a storage position: accumulate them.
    auto out_indices = flatten_full_indices->as_strided(out_sizes_, out_strides_, out_effective_offset);
    storage.index_add_(0, out_indices.reshape(-1), grad.reshape(-1));
  } else {
    // assume that new tensors have 0 storage offset
    // (storage comes from new_zeros, so out_effective_offset is relative to its start)
    storage.as_strided(out_sizes_, out_strides_, out_effective_offset).copy_(grad);
  }

  // Step (3): if input tensor has overlapping memory, divide scattered gradient
  //           at storage[i] by the number of times i shows up in input geometry
  if (inp_maybe_overlap) {
    auto count = at::zeros_like(storage, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    auto inp_indices = flatten_full_indices->as_strided(inp_sizes_, inp_strides_, inp_effective_offset).reshape(-1);
    count.index_add_(0, inp_indices, at::ones({1}, grad.options()).expand_as(inp_indices));
    storage.div_(count); // this will give nan outside visible range
  }
  // Step (4): return as_strided view of the storage tensor with input geometry
  // (the full input sizes/strides, including size-1 dims, so the result has
  // exactly the input's shape)
  return storage.as_strided(inp_sizes, inp_strides, inp_effective_offset);
}

std::tuple<Tensor, Tensor> atan2_backward(const Tensor& grad, const Tensor& self, const Tensor& other, std::array<bool, 2> output_mask) {
  if (!grad.defined()) {
    return std::tuple<Tensor, Tensor>{Tensor(), Tensor()};
  }
  auto recip = (self * self + other * other).reciprocal();
  return std::tuple<Tensor,Tensor>{
            output_mask[0] ? grad * other * recip : Tensor(),
            output_mask[1] ? grad * -self * recip : Tensor() };
}

// TODO: Seriously consider writing the derivative formulas for
// each output separately; there is not all that much sharing
// of computation going on here.
std::tuple<Tensor, Tensor, Tensor> prelu_double_backward(
    const Tensor & grad_grad_input,
    const Tensor & grad_grad_weight,
    const Tensor & grad_out,
    const Tensor & input_,
    const Tensor & weight_) {

  if (!(grad_grad_input.defined() || grad_grad_weight.defined() || grad_out.defined())) {
    return std::tuple<Tensor, Tensor, Tensor>(Tensor(), Tensor(), Tensor());
  }
    auto input = input_.contiguous();
    auto weight = weight_.contiguous();

  // Zero-fill undefined grads (TODO: do this more efficiently)
  auto ggI = grad_grad_input.defined() ? grad_grad_input.contiguous() : at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto ggW = grad_grad_weight.defined() ? grad_grad_weight.contiguous() : at::zeros_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto gO = grad_out.defined() ? grad_out.contiguous() : at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  auto positive_mask = (input > 0).type_as(ggI);
  auto nonpositive_mask = (input <= 0).type_as(ggW);

  // Explanation: Let input be i, weight be w, grad_output be gO.
  // f(i, w) = i      if i > 0
  //         = w * i  if i <= 0
  // gI = df/di * gO  = gO      if i > 0    gW = df/dw * gO = 0       if i > 0
  //                  = gO * w  if i <= 0                   = gO * i  if i <= 0
  // The rest is taking derivatives of these wrt i, w, gO and summing/expanding properly.

  if (weight.numel() == 1) {
      // from PReLU.forward: num_parameters == 0 is used indicate that a
      // single weight is shared among all input channels.

      // this is a little tricky because PReLU currently doesn't take a shape so the weight may be
      // 1-d when the input is a scalar (and there isn't a good Parameter API for that anyway until Variable
      // and tensor are merged).  So, use weight and ggW as 0-dim in this case.
      bool scalar_input_1d_weight = (positive_mask.dim() == 0 && weight.dim() == 1);
      auto weight_maybe_squeeze = scalar_input_1d_weight ? weight.squeeze() : weight;
      auto ggW_maybe_squeeze = scalar_input_1d_weight ? ggW.squeeze() : ggW;

      auto mask = positive_mask + nonpositive_mask * weight_maybe_squeeze.expand_as(input);
      auto ggO = ggI * mask + ggW_maybe_squeeze.expand_as(gO) * (nonpositive_mask * input);
      return std::tuple<Tensor, Tensor, Tensor>(
                ggO,
                ggW_maybe_squeeze.expand_as(gO) * gO * nonpositive_mask,
                (ggI * gO * nonpositive_mask).sum().expand_as(weight)
          );
  } else {
      // Expand ggW to match size of ggI; a simple expand doesn't work because
      // ggW is the size of the input channel (dim==1 unless there is only 1 dimension).  For example,
      // let ggI be size (3,4,5,6,7) and ggW be size (4).  Then we unsqueeze ggW to be size (4,1,1,1)
      // so the expand succeeds.
      auto dims_to_unsqueeze = std::max<int64_t>(input.dim() - 2, 0);
      auto ggW_expanded = ggW;
      for(const auto i : c10::irange(dims_to_unsqueeze)) {
          (void)i; // Suppress unused variable warning
          ggW_expanded = ggW_expanded.unsqueeze(1);
      }
      ggW_expanded = ggW_expanded.expand_as(ggI);

      auto gI = ggW_expanded * gO * nonpositive_mask;

      auto gW = ggI * gO * nonpositive_mask;
      if (input.dim() > 1) {
          gW = gW.sum(0);
      }
      while (gW.dim() > 1) {
          gW = gW.sum(1);
      }

      Tensor ggO;
      if (gO.requires_grad()) {
          // expand weight as input as in ggW/ggI above
          auto weight_expanded = weight;
          for(const auto i : c10::irange(dims_to_unsqueeze)) {
              (void)i; // Suppress unused variable warning
              weight_expanded = weight_expanded.unsqueeze(1);
          }
          weight_expanded = weight_expanded.expand_as(input);

          auto mask = positive_mask + nonpositive_mask * weight_expanded;
          ggO = ggI * mask + ggW_expanded * nonpositive_mask * input;
      }
      return std::tuple<Tensor,Tensor,Tensor>{ggO, gI, gW};
  }
}

// Computes grad * grad_output * input_scale. If `is_result` is false, that
// product is then passed through at::elu_backward with the same alpha / scale /
// input_scale / self_or_result arguments; if it is true, the product is used
// as-is. Either way, the value is zeroed wherever self_or_result >= 0.
// Presumably the double-backward term for ELU's backward w.r.t. self_or_result
// (TODO confirm against the derivative formula that calls it).
Tensor elu_double_backward(
    const Tensor& grad,
    const Tensor& grad_output,
    const Scalar& alpha,
    const Scalar& scale,
    const Scalar& input_scale,
    bool is_result,
    const Tensor& self_or_result) {
  const auto scaled = grad * grad_output * input_scale;
  const auto contribution = is_result
      ? scaled
      : at::elu_backward(scaled, alpha, scale, input_scale, is_result, self_or_result);
  return contribution * (self_or_result < 0).type_as(grad);
}

// Adapter over slice_backward for an optional slice range: a missing `start`
// becomes 0 and a missing `end` becomes INT64_MAX. All other arguments are
// forwarded unchanged.
Tensor slice_backward_wrapper(
    const at::Tensor& grad,
    const c10::IntArrayRef& input_sizes,
    int64_t dim,
    c10::optional<int64_t> start,
    c10::optional<int64_t> end,
    int64_t step) {
  const int64_t first = start.value_or(0);
  const int64_t last = end.value_or(INT64_MAX);
  return slice_backward(grad, input_sizes, dim, first, last, step);
}

// https://j-towns.github.io/papers/svd-derivative.pdf
//
// This makes no assumption on the signs of sigma.
// Backward of torch.svd, where self = U diag(sigma) V^H.
//
// grads holds the incoming gradients in the order {gU, gSigma, gV}. Any entry may
// be undefined; an undefined entry contributes nothing to the result.
// some       : mirrors the torch.svd flag. When false, raw_u / raw_v (and gU / gV)
//              are narrowed to their first k = sigma.size(-1) columns.
// compute_uv : must be true. Otherwise U and V were never computed and this errors.
// Returns the gradient with respect to `self`. It is the sum of a sigma term, a U
// term, a V term and, for complex inputs with gU defined, an extra imaginary term.
Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
          bool some, bool compute_uv, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v) {
  TORCH_CHECK(compute_uv,
           "svd_backward: Setting compute_uv to false in torch.svd doesn't compute singular matrices, ",
           "and hence we cannot compute backward. Please use torch.svd(compute_uv=True)");

  // self is (..., m, n); k is the number of singular values.
  auto m = self.size(-2);
  auto n = self.size(-1);
  auto k = sigma.size(-1);
  auto gsigma = grads[1];

  auto u = raw_u;
  auto v = raw_v;
  auto gu = grads[0];
  auto gv = grads[2];

  if (!some) {
    // We ignore the free subspace here because possible basis vectors cancel
    // each other, e.g., both -v and +v are valid bases for a dimension.
    // Don't assume behavior of any particular implementation of svd.
    u = raw_u.narrow(-1, 0, k);
    v = raw_v.narrow(-1, 0, k);
    if (gu.defined()) {
      gu = gu.narrow(-1, 0, k);
    }
    if (gv.defined()) {
      gv = gv.narrow(-1, 0, k);
    }
  }
  auto vh = v.mH();

  // Contribution of gSigma: U diag(gSigma) V^H (zeros if gSigma is undefined).
  Tensor sigma_term;
  if (gsigma.defined()) {
    gsigma = gsigma.to(self.dtype());
    // computes u @ diag(gsigma) @ vh
    sigma_term = at::matmul(u * gsigma.unsqueeze(-2), vh);
  } else {
    sigma_term = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }
  // in case that there are no gu and gv, we can avoid the series of kernel
  // calls below
  if (!gv.defined() && !gu.defined()) {
    return sigma_term;
  }

  auto uh = u.mH();
  // NOTE(review): sigma_inv is infinite for any zero singular value.
  auto sigma_inv = sigma.pow(-1);
  auto sigma_sq = sigma.pow(2);
  // F[i][j] = sigma_j^2 - sigma_i^2 before the inversion below.
  auto F = sigma_sq.unsqueeze(-2) - sigma_sq.unsqueeze(-1);
  // The following two lines invert the values of F and fill the diagonal with 0s.
  // Notice that F currently has 0s on the diagonal. So we fill the diagonal with +inf
  // first (1/inf = 0) to prevent nan from appearing in the backward of this function.
  // NOTE(review): off-diagonal entries still diverge for repeated singular values.
  F.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY);
  F = F.pow(-1);

  Tensor u_term, v_term;

  // Contribution of gU:
  //   U [F * (U^H gU - gU^H U)] diag(sigma) V^H
  //   + (I - U U^H) gU diag(1/sigma) V^H     (only when m > k)
  if (gu.defined()) {
    auto guh = gu.mH();
    u_term = at::matmul(u, F.mul(at::matmul(uh, gu) - at::matmul(guh, u)) * sigma.unsqueeze(-2));
    if (m > k) {
      // projection operator onto subspace orthogonal to span(U) defined as I - UU^H
      auto proj_on_ortho_u = -at::matmul(u, uh);
      proj_on_ortho_u.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).add_(1);
      u_term = u_term + proj_on_ortho_u.matmul(gu * sigma_inv.unsqueeze(-2));
    }
    u_term = at::matmul(u_term, vh);
  } else {
    u_term = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }

  // Contribution of gV:
  //   U diag(sigma) [F * (V^H gV - gV^H V)] V^H
  //   + U diag(1/sigma) gV^H (I - V V^H)     (only when n > k)
  if (gv.defined()) {
    auto gvh = gv.mH();
    v_term = sigma.unsqueeze(-1) * at::matmul(F.mul(at::matmul(vh, gv) - at::matmul(gvh, v)), vh);
    if (n > k) {
      // projection operator onto subspace orthogonal to span(V) defined as I - VV^H
      auto proj_on_v_ortho = -at::matmul(v, vh);
      proj_on_v_ortho.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).add_(1);
      v_term = v_term + sigma_inv.unsqueeze(-1) * at::matmul(gvh, proj_on_v_ortho);
    }
    v_term = at::matmul(u, v_term);
  } else {
    v_term = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }

  // for complex-valued input there is an additional term
  // https://giggleliu.github.io/2019/04/02/einsumbp.html
  // https://arxiv.org/abs/1909.02659
  // It uses only the imaginary part of diag(U^H gU), scaled by 1/sigma:
  //   U diag(i * Im(diag(U^H gU)) / sigma) V^H
  if (self.is_complex() && gu.defined()) {
    Tensor L = at::matmul(uh, gu).diagonal(0, -2, -1);
    at::real(L).zero_();
    at::imag(L).mul_(sigma_inv);
    Tensor imag_term = at::matmul(u * L.unsqueeze(-2), vh);
    return u_term + sigma_term + v_term + imag_term;
  }

  return u_term + sigma_term + v_term;
}

// The implementation follows:
// "An extended collection of matrix derivative results for forward and reverse mode algorithmic differentiation"
// https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
// However, the reference does not cover the constraint that the eigenvectors have unit (Euclidean) norm.
// See the details below.
// Backward of torch.eig. The eigenvalue handling below also accepts
// torch.linalg.eig-style (complex, 1-D per matrix) eigenvalues.
//
// grads = {D_grad, U_grad} are the gradients w.r.t. the eigenvalues and the
// eigenvectors. Either may be undefined.
// is_eigvec_tensor_nonempty must be true, i.e. eigenvectors must have been computed.
// Returns the gradient with respect to `self`.
Tensor eig_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
                    bool is_eigvec_tensor_nonempty, const Tensor& eigenvalues, const Tensor& eigenvectors) {
  TORCH_CHECK(is_eigvec_tensor_nonempty,
           "eig_backward: torch.eig(eigenvalues=False) is not differentiable. ",
           "Please use torch.linalg.eigvals");

  // variable names correspond to the ones in the reference document
  auto D = eigenvalues;
  const auto& U = eigenvectors;
  auto D_grad = grads[0];
  auto U_grad = grads[1];

  // The condition below is trying to marry torch.eig and torch.linalg.eig
  // for real inputs.
  //
  // For real inputs torch.eig returns a real 2D tensor representing real and complex
  // components of eigenvalues, while torch.linalg.eig will most likely always
  // return complex eigenvalues.
  if (!self.is_complex()) {
    auto is_imag_eigvals_zero = false;
    // path for torch.eig with always a "real" 2D tensor of eigenvalues
    if (!D.is_complex()) {
      // narrow extracts the column corresponding to the imaginary part
      is_imag_eigvals_zero = (D.narrow(-1, 1, 1) == 0.0).min().item<bool>();
    }
    // path for torch.linalg.eig with always a complex tensor of eigenvalues
    else {
      is_imag_eigvals_zero = (at::imag(D) == 0.0).min().item<bool>();
      // insert an additional dimension to be compatible with torch.eig.
      // Recall that it produces 2D tensors.
      // We extract only the real parts as there is no support for
      // complex eigenvalues with real inputs yet.
      D = at::real(D).unsqueeze(-1);
      // D_grad is undefined when only the eigenvectors received a gradient;
      // at::real / unsqueeze cannot be applied to an undefined tensor.
      if (D_grad.defined()) {
        D_grad = at::real(D_grad).unsqueeze(-1);
      }
    }
    // No support for complex eigenvalues for real inputs yet.
    TORCH_CHECK(
      is_imag_eigvals_zero,
      "eig_backward: Backward calculation does not support complex eigenvalues for real inputs at the moment.");
  }
  else {
    // torch.eig returns 2d tensors for eigenvalues,
    // while torch.linalg.eig returns 1d.
    // Hence we insert additional dimension for complex input,
    // such that the same code could be used for both methods.
    // It will become unnecessary once torch.eig is deprecated.
    D = D.unsqueeze(-1);
    if (D_grad.defined()) {
      D_grad = D_grad.unsqueeze(-1);
    }
  }

  if (!D_grad.defined() && !U_grad.defined()) {
    return at::zeros_like(self, at::MemoryFormat::Contiguous);
  }

  // Adapting the result from the reference above for the complex input, we get:
  //
  // A_grad = U^{-H} (D_grad + F.conj() * (U^H U_grad)) U^H,
  // where M^H := (M.mT()).conj() and * is the Hadamard (element-wise) product.
  //
  // torch.eig/torch.linalg.eig produce eigenvectors which are
  // normalized to unit norm, and the reference does not take that into account.
  // Hence, we have to modify the formula accordingly.
  //
  // Normalization to unit norm imposes the following constraint on the eigenvectors, i.e.
  // (U^H U) * I = I, where I is an identity matrix.
  // Forward AD for this expression yields:
  // (dU^H U + U^H dU) * I = 0 => U^H dU * I = 0 <=> diag(U^H dU) = 0, which means
  // that each i-th column of U is orthogonal to the i-th column of dU.
  // Now, the value of dU which does not take this constraint into consideration
  // comes straight from the reference:
  // dU = U(F * U^{-1} dA U).
  // To make sure that U^H dU * I = 0, and using U^H U * I = I (normalization),
  // we propose a modified forward AD for U:
  // dU_new = dU - U(U^H dU * I) (think of Gram-Schmidt)
  //
  // The rest is very similar to what is done in the reference and we finally arrive at:
  //
  // A_grad = U^{-H} (D_grad + (U^H U_grad - U^H U (U^H U_grad * I)) * F.conj()) U^H
  //        = U^{-H} (eigenvalues_contribs + eigenvectors_contribs) U^H, where
  // eigenvalues_contribs := D_grad,
  // eigenvectors_contribs := (U^H U_grad - U^H U (U^H U_grad * I)) * F.conj().
  // The contributions from the eigenvectors and the eigenvalues are computed below,
  // and then we solve the system
  // U^H A_grad = (eigenvalues_contribs + eigenvectors_contribs) U^H
  // to produce A_grad.

  // contribution from the eigenvectors
  Tensor U_contrib;
  if (U_grad.defined()) {
    // narrow extracts the column corresponding to the real part
    D = D.narrow(-1, 0, 1);
    // F[i][j] = D_j - D_i before the inversion below.
    auto F = (D.mT() - D);
    if (!F.is_complex()) {
      // 1 / inf = 0 on the diagonal.
      F.diagonal(0, -2, -1).fill_(INFINITY);
      F.pow_(-1);
    }
    else {
      // The F matrix construction for complex eigenvalues
      // is different from its real counterpart.
      // There is no complex INFINITY, and we cannot use
      //
      // F.pow_(-1);
      // F.diagonal(0, -2, -1).fill_(0);
      //
      // as it breaks gradgradcheck by double backward
      // propagating nans through F.pow_(-1) at zero,
      // the point of discontinuity.
      // Hence this hack below.
      F.diagonal(0, -2, -1).fill_(1);
      F.pow_(-1);
      F.diagonal(0, -2, -1).fill_(0);
    }
    auto U_grad_proj_onto_U = at::matmul(U.mH(), U_grad);
    auto Uh_U = at::matmul(U.mH(), U);
    U_contrib = (U_grad_proj_onto_U - Uh_U * U_grad_proj_onto_U.diagonal(0, -2, -1).unsqueeze(-2)) * F.conj();
  }
  else {
    U_contrib = at::zeros_like(self, at::MemoryFormat::Contiguous);
  }

  // contributions from the eigenvalues
  Tensor D_contrib;
  if (D_grad.defined()) {
    // narrow extracts the column corresponding to the real part
    D_contrib = D_grad.narrow(-1, 0, 1);
  }
  else {
    D_contrib = at::zeros_like(D, at::MemoryFormat::Contiguous);
  }

  // Solve U^H A_grad = (U_contrib + diag(D_contrib)) U^H for A_grad.
  return at::linalg_solve(U.mH(), at::matmul(U_contrib, U.mH()) + D_contrib * U.mH());
}

Tensor linalg_eig_backward(const std::vector<torch::autograd::Variable> &grads,
                           const Tensor& self,
                           const Tensor& L,
                           const Tensor& V) {
  // Reverse-mode derivative of A = V diag(L) V^{-1} (torch.linalg.eig).
  // Reference: https://arxiv.org/pdf/1701.00392.pdf Eq. 4.77
  //
  // With gL, gV the incoming gradients:
  //   gA = V^{-H} (diag_embed(gL) + (V^H gV - V^H V diag(Re(V^H gV))) / conj(E)) V^H
  // where E_ij = L_j - L_i off the diagonal. The quotient is only used off the
  // diagonal; the diagonal slot is taken by gL (or zero).
  // The -V^H V diag(Re(V^H gV)) term accounts for eigenvectors being returned
  // with unit norm.
  const auto gL = grads[0];
  const auto gV = grads[1];

  // A real input has a real gradient: keep only the real part in that case.
  const auto to_input_domain = [&self](const Tensor& g) -> Tensor {
    return self.is_complex() ? g : at::real(g);
  };

  if (!gV.defined()) {
    if (!gL.defined()) {
      // Neither output received a gradient.
      return at::zeros_like(self, at::MemoryFormat::Contiguous);
    }
    // Eigenvalues only: gA = V^{-H} diag(gL) V^H.
    return to_input_domain(at::linalg_solve(V.mH(), gL.unsqueeze(-1) * V.mH()));
  }

  const auto Vh = V.mH();
  const auto L_conj = L.conj();
  auto E_conj = L_conj.unsqueeze(-2) - L_conj.unsqueeze(-1);
  if (at::GradMode::is_enabled()) {
    // The diagonal is overwritten below, so any nonzero value works here; it only
    // keeps double backward from differentiating through a division by zero.
    E_conj.diagonal(0, -2, -1).fill_(1.);
  }

  const auto Vh_gV = at::matmul(Vh, gV);
  const auto re_diag = at::real(Vh_gV).diagonal(0, -2, -1);
  auto inner = Vh_gV - at::matmul(Vh, V * re_diag.unsqueeze(-2));
  inner.div_(E_conj);

  auto inner_diag = inner.diagonal(0, -2, -1);
  if (gL.defined()) {
    inner_diag.copy_(gL);
  } else {
    inner_diag.zero_();
  }

  // Conjugate by V^{-H}: V^{-H} inner V^H.
  return to_input_domain(at::linalg_solve(Vh, at::matmul(inner, Vh)));
}

// https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf, page 10
// see also https://arxiv.org/pdf/1701.00392.pdf Eqs. (4.60) and (4.63)
std::tuple<Tensor, Tensor> linalg_eig_jvp(const Tensor& dA,
                                          const Tensor& L,
                                          const Tensor& V) {
  const auto dAComplex = dA.to(c10::toComplexType(dA.scalar_type()));
  const auto dVfactor = at::linalg_solve(V, at::matmul(dAComplex, V));
  const auto Lconj = L.conj();
  auto FTimesdVfactor = dVfactor / (Lconj.unsqueeze(-2) - Lconj.unsqueeze(-1));
  FTimesdVfactor.diagonal(0, -2, -1).zero_();

  return std::make_tuple(
    dVfactor.diagonal(0, -2, -1),
    at::matmul(V, FTimesdVfactor));
}

Tensor linalg_lstsq_jvp(
  const Tensor& A,
  const Tensor& B,
  const Tensor& dA,
  const Tensor& dB
) {
  // Forward-mode derivative of the least-squares solution X = pinv(A) B:
  //   dX = d(pinv(A)) B + pinv(A) dB
  const auto A_pinv = at::linalg_pinv(A);
  const auto A_pinv_tangent = pinv_jvp(A, A_pinv, dA);
  return A_pinv_tangent.matmul(B) + A_pinv.matmul(dB);
}

std::tuple<Tensor, Tensor> linalg_lstsq_backward(
  const Tensor& grad,
  const Tensor& A,
  const Tensor& B,
  const c10::optional<double> rcond,
  const c10::optional<c10::string_view> driver,
  const std::array<bool, 2>& grad_input_mask
) {
  // Reverse-mode derivative of the least-squares solution, written as X = pinv(A) B.
  // rcond and driver do not enter the gradient formula.
  // Returns {grad_A, grad_B}. An entry stays undefined when it was not requested
  // through grad_input_mask, or when `grad` itself is undefined.
  Tensor A_grad;
  Tensor B_grad;
  if (!grad.defined()) {
    return std::make_tuple(A_grad, B_grad);
  }

  const bool need_A = grad_input_mask[0];
  const bool need_B = grad_input_mask[1];
  if (!need_A && !need_B) {
    return std::make_tuple(A_grad, B_grad);
  }

  // pinv(A) is needed by both gradients; compute it once.
  const auto pinvA = at::linalg_pinv(A);

  if (need_A) {
    // dX = d(pinv(A)) B, so the gradient w.r.t. pinv(A) is grad B^H.
    A_grad = pinv_backward(grad.matmul(B.transpose(-1, -2).conj()), pinvA, A);
  }

  if (need_B) {
    // dX = pinv(A) dB, so grad_B = pinv(A)^H grad. This equals
    // std::get<0>(at::linalg_lstsq(A^H, grad, rcond, driver)), but that route
    // can use the non-deterministic `gelsy` driver, so the product is formed directly.
    B_grad = pinvA.transpose(-1, -2).conj().matmul(grad);
  }

  return std::make_tuple(A_grad, B_grad);
}


// jvp functions for eigenvalues and eigenvectors are separate
// because currently forward AD only works with one rule per output
Tensor eigh_jvp_eigenvalues(
    const Tensor& input_tangent,
    const Tensor& eigenvalues,
    const Tensor& eigenvectors) {
  // Forward-mode derivative of the eigenvalues of a Hermitian matrix A = V diag(L) V^H:
  //   dL = diag(V^H dA V)
  // Reference: "An extended collection of matrix derivative results for forward and
  // reverse mode automatic differentiation", Section 3.1,
  // https://ora.ox.ac.uk/objects/uuid:8d0c0a29-c92b-4153-a1d2-38b276e93124
  // `eigenvalues` is not needed by this formula.

  // TODO: gradcheck from test_ops.py hangs with complex inputs
  TORCH_CHECK_NOT_IMPLEMENTED(
      !input_tangent.is_complex(),
      "the derivative for 'eigh' with complex inputs is not implemented.");

  // Only the Hermitian part of the tangent is a valid perturbation of a
  // Hermitian input, so symmetrize it first.
  const auto hermitian_part = 0.5 * (input_tangent + input_tangent.mH());

  const auto projected = at::matmul(at::matmul(eigenvectors.mH(), hermitian_part), eigenvectors);
  auto dL = projected.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1);
  // Eigenvalues of a Hermitian matrix are real.
  return dL.is_complex() ? at::real(dL) : dL;
}

Tensor eigh_jvp_eigenvectors(
    const Tensor& input_tangent,
    const Tensor& eigenvalues,
    const Tensor& eigenvectors) {
  // Forward-mode derivative of the eigenvectors of a Hermitian matrix A = V diag(L) V^H:
  //   dV = V ((V^H dA V) / E),   E_ij = L_j - L_i, with the diagonal of the quotient set to 0.
  // Reference: "An extended collection of matrix derivative results for forward and
  // reverse mode automatic differentiation", Section 3.1,
  // https://ora.ox.ac.uk/objects/uuid:8d0c0a29-c92b-4153-a1d2-38b276e93124

  TORCH_CHECK_NOT_IMPLEMENTED(
      !input_tangent.is_complex(),
      "the derivative for 'eigh' with complex inputs is not implemented.");

  // Only the Hermitian part of the tangent is a valid perturbation of a Hermitian input.
  const auto hermitian_part = 0.5 * (input_tangent + input_tangent.mH());
  const auto projected = at::matmul(at::matmul(eigenvectors.mH(), hermitian_part), eigenvectors);

  // Dividing by +inf sends the diagonal of the quotient to zero.
  auto gaps = eigenvalues.unsqueeze(-2) - eigenvalues.unsqueeze(-1);
  gaps.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY);

  return at::matmul(eigenvectors, projected.div(gaps));
}

Tensor eigh_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
                     bool eigenvectors, const Tensor& L, const Tensor& V) {
  // Reverse-mode derivative shared by torch.symeig and torch.linalg.eigh.
  // Both decompose a Hermitian (real: symmetric) A = V diag(L) V^H, with V unitary and L real.
  //
  // [Note: eigh backward]
  // Geometry. Let U(n) be the unitary matrices and Her(n) the Hermitian ones, so
  // eigh : Her(n) -> U(n) x R^n. Her(n) and R^n are linear spaces, hence their own
  // tangent spaces. Differentiating V^H V = I shows that
  //   T_V U(n) = {X | V^H X is skew-Hermitian}.
  // The adjoint of the differential therefore acts on T_V U(n). An arbitrary
  // incoming gV must first be projected onto it:
  //   pi_V(gV) = V pi_I(V^H gV),   pi_I(X) = (X - X^H) / 2,
  // where pi_I is the orthogonal projection onto the skew-Hermitian matrices.
  //
  // Formula (https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf, Sec. 3.1):
  //   gA = V (diag_embed(gL) + pi_I(V^H gV) / E) V^H
  // with E_ij = L_j - L_i. The quotient is used only off the diagonal; the
  // diagonal holds gL (or zero when gL is undefined).

  // Only torch.symeig can reach this check.
  TORCH_CHECK(eigenvectors,
           "eigh_backward: torch.symeig(A, eigenvectors=False) is not differentiable. ",
           "Use torch.linalg.eigvalsh(A) instead.");

  const auto gL = grads[0];
  const auto gV = grads[1];

  if (!gV.defined()) {
    if (!gL.defined()) {
      // No gradient flowed into either output.
      return at::zeros_like(self, at::MemoryFormat::Contiguous);
    }
    // Eigenvalues only: V diag(gL) V^H, a single matmul via broadcasting.
    return at::matmul(V * gL.unsqueeze(-2), V.mH());
  }

  auto gaps = L.unsqueeze(-2) - L.unsqueeze(-1);
  if (at::GradMode::is_enabled()) {
    // The diagonal is overwritten below, so any nonzero value works; it keeps
    // double backward from differentiating through a division by zero.
    gaps.diagonal(0, -2, -1).fill_(1);
  }

  const Tensor Vh_gV = at::matmul(V.mH(), gV);
  // Skew-Hermitian projection pi_I(V^H gV).
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  Tensor inner = Vh_gV.sub(Vh_gV.mH()).mul_(0.5);
  // gaps is skew-symmetric. A skew-Hermitian matrix divided entrywise by it is Hermitian.
  inner.div_(gaps);

  auto inner_diag = inner.diagonal(0, -2, -1);
  if (gL.defined()) {
    inner_diag.copy_(gL);
  } else {
    inner_diag.zero_();
  }

  // V inner V^H stays Hermitian because V is unitary.
  return at::matmul(V, at::matmul(inner, V.mH()));
}

std::tuple<Tensor, Tensor> linalg_qr_jvp(
  const Tensor& dA,
  const Tensor& Q,
  const Tensor& R
) {
  auto m = dA.size(-2);
  auto n = dA.size(-1);
  auto k = std::min(m, n);

  auto dA1 = dA.narrow(-1, 0, k);
  auto R1 = R.narrow(-1, 0, k);

  // dB1 = Q^H dA1 R1^{-1}
  auto dB1 = at::linalg_solve_triangular(R1, Q.mH().matmul(dA1), /*upper=*/true, /*left=*/false);

  // dC1 = (dB1 + dB1^H).triu(-1) + (dB1 + dB1^H) * 0.5 I
  auto dC1 = (dB1 + dB1.mH()).triu();
  dC1.diagonal(0, -2, -1).mul_(0.5);

  auto dR1 = dC1.matmul(R1);

  // dQ = (dA1 - Q dR1) R1^{-1}
  auto dQ = at::linalg_solve_triangular(R1, dA1 - Q.matmul(dR1), /*upper=*/true, /*left=*/false);

  Tensor dR;
  if (m >= n) {
    dR = dR1;
  }
  else {
    auto dA2 = dA.narrow(-1, k, n - k);
    auto R2 = R.narrow(-1, k, n - k);
    auto dR2 = Q.mH().matmul(dA2 - dQ.matmul(R2));
    dR = at::cat({dR1, dR2}, -1);
  }

  return std::make_tuple(dQ, dR);
}

Tensor linalg_qr_jvp_Q(
  const Tensor& dA,
  const Tensor& Q,
  const Tensor& R
) {
  // Forward-mode tangent of the Q factor. The R tangent computed alongside is discarded.
  Tensor dQ;
  std::tie(dQ, std::ignore) = linalg_qr_jvp(dA, Q, R);
  return dQ;
}

Tensor linalg_qr_jvp_R(
  const Tensor& dA,
  const Tensor& Q,
  const Tensor& R
) {
  // Forward-mode tangent of the R factor. The Q tangent computed alongside is discarded.
  Tensor dR;
  std::tie(std::ignore, dR) = linalg_qr_jvp(dA, Q, R);
  return dR;
}

Tensor linalg_qr_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
                          c10::string_view mode, const Tensor& q, const Tensor& r){
  // Reverse-mode derivative of torch.linalg.qr for self = q r.
  // grads = {grad_Q, grad_R}; either may be undefined.
  // Supported: mode='reduced', and mode='complete' with nrows <= ncols.
  //
  // References:
  //   Seeger, Hetzel, Dai, Meissner, Lawrence (2018). Auto-Differentiating Linear Algebra.
  //     https://arxiv.org/abs/1710.08717 Sec. 4.3 (LQ decomposition, the transpose of QR)
  //   Liao, Liu, Wang, Xiang (2019). Differentiable Programming Tensor Networks.
  //     https://arxiv.org/abs/1903.09650 Sec. 3 (QR factorization)
  //   Complex-valued inputs: https://giggleliu.github.io/2019/04/02/einsumbp.html

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  bool compute_q, reduced;
  std::tie(compute_q, reduced) = at::native::_parse_qr_mode(mode);
  TORCH_CHECK(compute_q, "The derivative of qr is not implemented when mode='r'. "
                         "Please use torch.linalg.qr(..., mode='reduced')");

  const auto m = self.size(-2);
  const auto n = self.size(-1);
  TORCH_CHECK(
      reduced || m <= n,
      "The derivative of qr is not implemented when mode='complete' and nrows > ncols.");

  // Gradient for a square or tall input (R square), with undefined gradients
  // treated as zero:
  //   M      = R gR^H - gQ^H Q
  //   M     <- tril(M) + tril(M)^H, with the diagonal halved
  //   grad_A = (gQ + Q M) R^{-H}
  const auto square_or_tall_grad = [](const Tensor& gQ, const Tensor& gR,
                                      const Tensor& Q, const Tensor& R) -> Tensor {
    const Tensor from_R = gR.defined()
        ? at::matmul(R, gR.mH())
        : at::zeros_like(R, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    // gQ^H Q has the same (..., N, N) shape as R.
    const Tensor from_Q = gQ.defined()
        ? at::matmul(gQ.mH(), Q)
        : at::zeros_like(R, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

    const Tensor lower = at::tril(from_R - from_Q);
    Tensor sym = lower + lower.mH();
    sym.diagonal(0, -2, -1).mul_(0.5);

    const Tensor Q_sym = at::matmul(Q, sym);
    const Tensor rhs = gQ.defined() ? gQ + Q_sym : Q_sym;

    // rhs @ R^{-H}: solve X R^H = rhs, where R^H is lower-triangular.
    return at::linalg_solve_triangular(
        R.transpose(-2, -1).conj(),
        rhs,
        /*upper=*/false,
        /*left=*/false,
        /*unitriangular=*/false);
  };

  const auto grad_Q = grads[0];
  const auto grad_R = grads[1];

  if (m >= n) {
    return square_or_tall_grad(grad_Q, grad_R, q, r);
  }

  // Wide input (m < n). Partition A = [X | Y] and R = [U | V], with X and U
  // square. Correspondingly grad_R = [grad_U | grad_V] and grad_A = [grad_X | grad_Y]:
  //   grad_X = square_or_tall_grad(grad_Q + Y grad_V^H, grad_U, Q, U)
  //   grad_Y = Q grad_V
  const auto Y = self.narrow(-1, m, n - m);
  const auto U = r.narrow(-1, 0, m);

  Tensor grad_U, grad_V, grad_Q_prime;
  if (grad_R.defined()) {
    grad_U = grad_R.narrow(-1, 0, m);
    grad_V = grad_R.narrow(-1, m, n - m);
    grad_Q_prime = at::matmul(Y, grad_V.mH());
  } else {
    // No R gradient: grad_V and the Y grad_V^H term are zero; grad_U stays undefined.
    grad_V = at::zeros_like(Y, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    grad_Q_prime = at::zeros_like(q, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }
  if (grad_Q.defined()) {
    grad_Q_prime = grad_Q_prime + grad_Q;
  }

  const auto grad_X = square_or_tall_grad(grad_Q_prime, grad_U, q, U);
  const auto grad_Y = at::matmul(q, grad_V);
  return at::cat({grad_X, grad_Y}, -1);
}

// Based on:
//
// Mathias, Roy.
// A Chain Rule for Matrix Functions and Applications.
// SIAM J. Matrix Anal. Appl. 17 (1996): 610-620.

template <typename func_t>
Tensor differential_analytic_matrix_function(
    const Tensor& self, const Tensor& grad,
    const func_t& matrix_function,
    const bool adjoint // Choose between forward (adjoint=false) or backward AD (adjoint=true)
  ) {
  // Computes the differential Df(A)[grad] of an analytic matrix function f
  // (forward AD), or its adjoint (backward AD, where A is replaced by A^H).
  // Uses the block-matrix identity
  //   f([[A, E], [0, A]]) = [[f(A), Df(A)[E]], [0, f(A)]],
  // building the 2n x 2n (batched) block matrix and reading off its top-right block.
  const auto A = adjoint ? self.transpose(-2, -1).conj() : self;
  const auto n = A.size(-1);
  const auto ndim = A.dim();

  auto block_sizes = A.sizes().vec();
  block_sizes[ndim - 2] *= 2;
  block_sizes[ndim - 1] *= 2;

  auto block = at::zeros(block_sizes, grad.options());
  auto top_rows = block.narrow(-2, 0, n);
  auto bottom_rows = block.narrow(-2, n, n);
  top_rows.narrow(-1, 0, n).copy_(A);     // top-left: A
  bottom_rows.narrow(-1, n, n).copy_(A);  // bottom-right: A
  top_rows.narrow(-1, n, n).copy_(grad);  // top-right: E

  return matrix_function(block).narrow(-2, 0, n).narrow(-1, n, n);
}

Tensor linalg_matrix_exp_differential(const Tensor& self, const Tensor& grad, bool adjoint) {
  // Differential (adjoint=false) or adjoint differential (adjoint=true) of the
  // matrix exponential at `self`, applied to `grad`.
  // TF32 is disabled while the guard is alive, i.e. for this whole computation.
  at::NoTF32Guard disable_tf32;
  return differential_analytic_matrix_function(self, grad, at::linalg_matrix_exp, adjoint);
}

Tensor det_backward(const Tensor & grad, const Tensor& self, const Tensor& det) {
  // Reverse-mode derivative of det(self).
  // Each matrix uses a closed form (Jacobi's formula) when |det| is above a
  // cutoff, and an SVD-based formula otherwise.
  if (self.numel() == 0) {
    return at::empty_like(self);
  }

  // Jacobi's formula (https://en.wikipedia.org/wiki/Jacobi%27s_formula):
  //   A_grad = A^{-H} (g * conj(det)) I,   with A^{-H} = (A^{-1})^H,
  // evaluated as a linear solve against A^H.
  const auto nonsingular_grad = [](const Tensor& g, const Tensor& A, const Tensor& d) -> Tensor {
    auto scaled_identity = at::zeros_like(A);
    scaled_identity.diagonal(0, -2, -1).copy_((g * d.conj()).unsqueeze(-1));
    return at::linalg_solve(A.mH(), scaled_identity);
  };

  // Gradient obtained as if det were computed through the SVD A = U diag(S) Vh:
  //   det(A) = det(U) * prod(S) * det(Vh).
  // U and Vh are unitary, so Jacobi's formula applies to them. The result is
  // pushed back through svd_backward. This formula is also valid for
  // nonsingular A. The `d` argument is unused.
  const auto singular_grad = [&nonsingular_grad](const Tensor& g, const Tensor& A, const Tensor& /*d*/) -> Tensor {
    Tensor U, S, Vh;
    std::tie(U, S, Vh) = at::linalg_svd(A);
    const auto det_U = at::linalg_det(U);
    const auto prod_S = S.prod(-1);
    const auto det_Vh = at::linalg_det(Vh);

    const auto U_grad = nonsingular_grad(g * (det_Vh * prod_S).conj(), U, det_U);

    const auto prod_S_grad = handle_r_to_c(prod_S.scalar_type(), g * (det_U * det_Vh).conj());
    const auto S_grad = prod_backward(prod_S_grad, S, prod_S, -1, false);

    const auto Vh_grad = nonsingular_grad(g * (det_U * prod_S).conj(), Vh, det_Vh);

    // svd_backward expects the (U, S, V) convention of torch.svd, whereas
    // linalg_svd returns Vh = V^H; convert both Vh and its gradient.
    const auto V = Vh.mH();
    const auto V_grad = Vh_grad.mH();
    return svd_backward({U_grad, S_grad, V_grad}, A, true, true, U, S, V);
  };

  const auto eps = at::native::_get_epsilon(c10::toValueType(self.scalar_type()));
  const auto cutoff = eps * at::linalg_matrix_norm(self);

  if (self.dim() == 2) {
    const bool looks_singular = det.abs().lt(cutoff).item<bool>();
    return looks_singular ? singular_grad(grad, self, det)
                          : nonsingular_grad(grad, self, det);
  }

  // Batched input: use a single formula when every matrix agrees.
  const auto nonsingular_mask = det.abs().ge(cutoff);
  if (nonsingular_mask.all().item<bool>()) {
    return nonsingular_grad(grad, self, det);
  }
  const auto singular_mask = nonsingular_mask.logical_not();
  if (singular_mask.all().item<bool>()) {
    return singular_grad(grad, self, det);
  }

  // Mixed batch: evaluate each formula only on the matrices it applies to and
  // scatter the results into the output.
  Tensor self_grad = self.new_empty(self.sizes(), self.options());
  const auto scatter_branch = [&](const Tensor& mask, const auto& branch) {
    const auto idx = at::native::toListOfOptionalTensors(mask);
    self_grad.index_put_(idx, branch(grad.index(idx), self.index(idx), det.index(idx)));
  };
  scatter_branch(nonsingular_mask, nonsingular_grad);
  scatter_branch(singular_mask, singular_grad);
  return self_grad;
}

// Backward of _det_lu_based_helper.
//
// In the single-backward case this is a specialized version of lu.backward
// (implemented in /torch/_autograd_functions.py), evaluated by a dedicated
// helper kernel. When the backward pass itself is being recorded (grad mode
// enabled, i.e. double backward), it falls back to det_backward instead.
Tensor _det_lu_based_helper_backward(
  const Tensor& det_grad,
  const Tensor& det,
  const Tensor& self,
  const Tensor& lu,
  const Tensor& pivs
) {
  // Empty input: the gradient is a contiguous zero tensor shaped like `self`
  // (checked before det_grad so this holds even for an undefined det_grad).
  if (self.numel() == 0) {
    return at::zeros_like(self, at::MemoryFormat::Contiguous);
  }
  if (!det_grad.defined()) {
    return Tensor();
  }

  // Why two code paths:
  //  * The LU-based helper is the more stable choice for the first derivative
  //    of det-computing functions, but it does not pass double backward
  //    gradient checks (gradgradcheck).
  //  * det_backward is less stable, because svd_backward requires singular
  //    values that are distinct and sufficiently far apart. However, whenever
  //    det_backward itself is stable, so is its double backward.
  // Grad mode being enabled here means a graph of this backward is being
  // built (double backward), so det_backward is used; such gradients may be
  // unstable, and double backpropagation through det-computing functions is
  // therefore NOT recommended.
  if (!at::GradMode::is_enabled()) {
    // `lu` and `pivs` come straight from a LAPACK routine, so a fused
    // sequence of kernels can skip the usual memory copies and checks.
    return at::_det_lu_based_helper_backward_helper(det_grad, det, self, lu, pivs);
  }
  return det_backward(det_grad, self, det);
}

// Backward of logdet.
//
// Matrices with logdet == -inf (singular) are differentiated through their SVD
// (logdet = sum(log(sigma))). All other matrices use the closed form
// grad * A^{-T}. For batched inputs the two cases are handled separately and
// scattered back into one result.
Tensor logdet_backward(const Tensor & grad, const Tensor& self, const Tensor& logdet) {
  // Singular case: differentiate sum(log(sigma)) through svd_backward.
  const auto svd_based_grad = [&](const Tensor& g, const Tensor& A) -> Tensor {
    Tensor U, S, Vh;
    std::tie(U, S, Vh) = at::linalg_svd(A, false);
    const auto grad_S = g.unsqueeze(-1).div(S);
    return svd_backward({{}, grad_S, {}}, A, true, true, U, S, Vh.mH());
  };

  // Non-singular case: d(logdet A) = A^{-T}.
  // NOTE(review): slogdet_backward uses mH() here; mT() and mH() only agree for
  // real inputs — confirm logdet is restricted to real dtypes.
  const auto inverse_based_grad = [&](const Tensor& g, const Tensor& A) -> Tensor {
    return unsqueeze_multiple(g, {-1, -2}, A.dim()) * A.inverse().mT();
  };

  if (self.dim() == 2) {
    // A NaN logdet is not -inf, so it takes the non-singular path.
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
    const bool is_singular = logdet.item<double>() == -INFINITY;
    return is_singular ? svd_based_grad(grad, self) : inverse_based_grad(grad, self);
  }

  // Batched input: split the batch into matrices with logdet != -inf and
  // logdet == -inf. Only the first index tensor of each at::where result is
  // needed to count how many matrices fall into the group.
  auto finite_idx = at::native::toListOfOptionalTensors(at::where(logdet != -INFINITY));
  const c10::optional<Tensor> finite_batch_idx = finite_idx[0];
  if (finite_batch_idx->size(0) == logdet.numel()) {
    // Every matrix is non-singular.
    return inverse_based_grad(grad, self);
  }

  auto singular_idx = at::native::toListOfOptionalTensors(at::where(logdet == -INFINITY));
  const c10::optional<Tensor> singular_batch_idx = singular_idx[0];
  if (singular_batch_idx->size(0) == logdet.numel()) {
    // Every matrix is singular.
    return svd_based_grad(grad, self);
  }

  // Mixed batch: the two index sets cover every matrix, so no entry of the
  // uninitialized result survives.
  Tensor result = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  result.index_put_(finite_idx,
                    inverse_based_grad(grad.index(finite_idx), self.index(finite_idx)));
  result.index_put_(singular_idx,
                    svd_based_grad(grad.index(singular_idx), self.index(singular_idx)));
  return result;
}

// Backward of slogdet with respect to the logabsdet output.
//
// Matrices with signdet == 0 (singular) are differentiated through their SVD.
// All other matrices use the closed form grad * A^{-H}. For batched inputs the
// two cases are handled separately and scattered back into one result.
Tensor slogdet_backward(const Tensor& grad_logabsdet,
                        const Tensor& self,
                        const Tensor& signdet, const Tensor& logabsdet) {
  // Singular case: every singular value sigma is non-negative and at least one
  // is zero. logabsdet = sum(log(abs(sigma))), but since det = 0 it is
  // differentiated as sum(log(sigma)) through svd_backward.
  const auto svd_based_grad = [&](const Tensor& g, const Tensor& A) -> Tensor {
    Tensor U, S, Vh;
    std::tie(U, S, Vh) = at::linalg_svd(A, false);
    const auto grad_S = g.unsqueeze(-1).div(S);
    return svd_backward({{}, grad_S, {}}, A, true, true, U, S, Vh.mH());
  };

  // Non-singular case: d(log|det A|) = A^{-H}.
  const auto inverse_based_grad = [&](const Tensor& g, const Tensor& A) -> Tensor {
    // TODO: replace self.inverse with linalg_inverse
    return unsqueeze_multiple(g, {-1, -2}, A.dim()) * A.inverse().mH();
  };

  // For complex inputs the sign is tested through its magnitude.
  const Tensor sign_magnitude = self.is_complex() ? signdet.abs() : signdet;

  if (self.dim() == 2) {
    const bool is_singular = sign_magnitude.item<double>() == 0;
    return is_singular ? svd_based_grad(grad_logabsdet, self)
                       : inverse_based_grad(grad_logabsdet, self);
  }

  // Batched input: split into matrices with a non-zero sign and matrices with
  // a zero sign. Only the first index tensor of each at::where result is
  // needed to count how many matrices fall into the group.
  auto invertible_idx = at::native::toListOfOptionalTensors(at::where(sign_magnitude));
  const c10::optional<Tensor> invertible_batch_idx = invertible_idx[0];
  if (invertible_batch_idx->size(0) == logabsdet.numel()) {
    // Every matrix is non-singular.
    return inverse_based_grad(grad_logabsdet, self);
  }

  auto singular_idx = at::native::toListOfOptionalTensors(at::where(signdet == 0));
  const c10::optional<Tensor> singular_batch_idx = singular_idx[0];
  if (singular_batch_idx->size(0) == logabsdet.numel()) {
    // Every matrix is singular.
    return svd_based_grad(grad_logabsdet, self);
  }

  // Mixed batch: the two index sets cover every matrix.
  Tensor result = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  result.index_put_(invertible_idx,
                    inverse_based_grad(grad_logabsdet.index(invertible_idx),
                                       self.index(invertible_idx)));
  result.index_put_(singular_idx,
                    svd_based_grad(grad_logabsdet.index(singular_idx),
                                   self.index(singular_idx)));
  return result;
}

// Reference:
// https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
// Sec. 2.3.1 Matrix inverse product
std::tuple<Tensor, Tensor> triangular_solve_backward(
    const Tensor & grad_x, const Tensor & grad_m,
    const Tensor & b, const Tensor & a, const Tensor & x,
    const bool upper, const bool transpose, const bool unitriangular,
    std::array<bool, 2> output_mask) {
  Tensor grad_b, grad_a;
  if (grad_x.defined() || grad_m.defined()) {
    if (grad_x.defined()) {
      grad_b = std::get<0>(grad_x.triangular_solve(a.conj(), upper, !transpose, unitriangular));
      if (output_mask[1]) {
        grad_a = transpose ? -x.conj().matmul(grad_b.mT()) : -grad_b.matmul(x.mH());
        if (upper) {
          grad_a = grad_a.triu((int) unitriangular);
        } else {
          grad_a = grad_a.tril(-((int) unitriangular));
        }
      }
    }
    if (!grad_a.defined()) {
      grad_a = at::zeros({1}, a.options()).expand_as(a);
    }
    if (!grad_b.defined()) {
      grad_b = at::zeros({1}, b.options()).expand_as(b);
    }
    if (output_mask[1] && grad_m.defined()) {
      grad_a = grad_a.add(grad_m);
    }
  }
  return std::tuple<Tensor, Tensor>{grad_b, grad_a};
}

// Forward-mode derivative of triangular_solve. The generic solve JVP is
// specialized with a triangular solver that uses the same flags as the forward.
Tensor triangular_solve_jvp(
  const Tensor& X, const Tensor& A,
  const Tensor& dA, const Tensor& dB,
  const bool upper,
  const bool transpose,
  const bool unitriangular
) {
  const auto triangular_solver = [&](const Tensor& coeff, const Tensor& rhs, const Tensor& dA_contrib) {
    const Tensor shifted_rhs = rhs - dA_contrib;
    return std::get<0>(at::triangular_solve(shifted_rhs, coeff, upper, transpose, unitriangular));
  };
  return generic_solve_jvp(triangular_solver, X, A, dA, dB);
}

// Forward-mode derivative of linalg_solve_triangular.
// For left = true the tangent is A^{-1}(B_t - A_t X), and for left = false it
// is (B_t - X A_t) A^{-1}. A_t is first projected onto the triangular part
// that is actually read. For the derivation see:
// [Note: Forward / Backward AD solve_triangular]
Tensor linalg_solve_triangular_forward_AD(
    const Tensor& A_t,
    const Tensor& B_t,
    const Tensor& A,
    const Tensor& X,
    const bool upper,
    const bool left,
    const bool unitriangular) {
  const int diag_offset = static_cast<int>(unitriangular);
  const Tensor proj_A_t = upper ? A_t.triu(diag_offset) : A_t.tril(-diag_offset);
  const Tensor A_t_X = left ? at::matmul(proj_A_t, X) : at::matmul(X, proj_A_t);
  return at::linalg_solve_triangular(A, B_t - A_t_X, upper, left, unitriangular);
}

// Backward of linalg_solve_triangular. Returns {grad_A, grad_B}; each entry is
// undefined when not requested through output_mask (or when grad is undefined).
std::tuple<Tensor, Tensor> linalg_solve_triangular_backward(
    const Tensor& grad,
    const Tensor& A,
    const Tensor& X,
    const bool upper,
    const bool left,
    const bool unitriangular,
    std::array<bool, 2> output_mask) {
  const bool A_requires_grad = output_mask[0];
  const bool B_requires_grad = output_mask[1];
  // [Note: Forward / Backward AD solve_triangular]
  // Assume left=true for simplicity.
  // Remark: A solver computes A^{-1}B
  //
  // Forward AD:
  // If f(A) = A^{-1}, differentiating the equation A^{-1}A = I_n gives
  // (df)_A(E) = -A^{-1}EA^{-1}
  // As such, if g(A,B) = A^{-1}B,
  // (dg)_(A,B)(E_A, E_B) = -A^{-1}E_AA^{-1}B + A^{-1}E_B
  //                      = A^{-1}(E_B - E_AX)
  //
  // Backward AD:
  // Denoting the gradients by G_A, G_B, we solve above to give
  // G_B = A^{-H}G_X
  // G_A = -A^{-H}G_XX^H = -G_B X^H
  //
  // Note that B is needed neither for forward nor for backward AD.
  //
  // These formulas work for a general solver of linear equations.
  // When A is triangular, G_A is the projection of the formula above onto the
  // triangular matrices, i.e. simply taking triu (resp. tril) of it.
  // This is because the triangular matrices form a vector space, so the tangent
  // space at any point is itself the space of triangular matrices. The result
  // follows from a reasoning like the one at the end of [Note: eigh backward].
  // Something similar happens for `unitriangular`, except that in this case the
  // tangent space is the set of strictly upper (resp. strictly lower)
  // triangular matrices, i.e. triangular with zeros on the diagonal.

  if (!grad.defined() || !(A_requires_grad || B_requires_grad)) {
    return std::make_tuple(Tensor{}, Tensor{});
  }

  // G_B is needed for either gradient. For left = false the transposed system
  // is solved on the right, hence the same `left` flag.
  const Tensor G_B = at::linalg_solve_triangular(A.mH(), grad, !upper, left, unitriangular);

  if (!A_requires_grad) {
    return std::make_tuple(Tensor{}, G_B);
  }

  const Tensor X_H = X.mH();
  Tensor G_A = left ? -at::matmul(G_B, X_H) : -at::matmul(X_H, G_B);
  // Project onto the (possibly strictly) triangular tangent space.
  const int diag_offset = static_cast<int>(unitriangular);
  G_A = upper ? G_A.triu(diag_offset) : G_A.tril(-diag_offset);
  return std::make_tuple(G_A, B_requires_grad ? G_B : Tensor{});
}

std::tuple<Tensor, Tensor> cholesky_solve_backward(
    const Tensor& grad_x, const Tensor& self,
    const Tensor& input2, const Tensor& result, const bool upper) {
  Tensor grad_self, grad_input2;
  if (grad_x.defined()) {
    grad_self = grad_x.cholesky_solve(input2, /*upper=*/upper);

    Tensor common_term = at::matmul(grad_self, result.mH());
    common_term = common_term + common_term.mH();

    if (upper) {
      grad_input2 = -at::matmul(input2, common_term);
    } else {
      grad_input2 = -at::matmul(common_term, input2);
    }
  }
  return std::tuple<Tensor, Tensor>{grad_self, grad_input2};
}

// Forward-mode derivative of cholesky_solve.
// With A = U^H U (upper) or A = U U^H (lower), dA = dK + dK^H, where
// dK = dU^H U (upper) or dU U^H (lower). The generic solve JVP is then
// specialized with cholesky_solve.
Tensor cholesky_solve_jvp(
  const Tensor& X,
  const Tensor& U,
  const Tensor& dU,
  const Tensor& dB,
  const bool upper
) {
  const Tensor dK = upper ? dU.mH().matmul(U) : dU.matmul(U.mH());
  const Tensor dA = dK + dK.mH();
  const auto cholesky_solver = [&](const Tensor& factor, const Tensor& rhs, const Tensor& dA_contrib) {
    return at::cholesky_solve(rhs - dA_contrib, factor, upper);
  };
  return generic_solve_jvp(cholesky_solver, X, /*A=*/U, dA, dB);
}

// Backward of a onesided C2R transform (irfft).
//
// Conceptually the forward
//    1. fills in the missing half of the spectrum by conjugate symmetry,
//    2. runs an inverse C2C transform,
//    3. discards the imaginary part.
// The backward is therefore
//    1. an R2C transform of `grad` (a onesided rfft), followed by
//    2. accumulating the gradient of the entries that forward step 1 produced
//       by conjugate symmetry.
// Let N = grad.size(last dim) and onesided_length = floor(N/2) + 1. Entry idx
// of the onesided result is mirrored to (N - idx) % N, and it must be doubled
// exactly when that mirror lies outside the onesided range [0, floor(N/2)]:
//    i.   idx = 0: mirrored to 0. Not doubled.
//    ii.  0 < idx < floor(N/2): N > N - idx > ceil(N/2), so it is mirrored
//         outside the onesided range. Doubled.
//    iii. idx = floor(N/2) = N/2 when N is even: mirrored to N/2. Not doubled.
//    iv.  idx = floor(N/2) = (N-1)/2 when N is odd: mirrored to (N+1)/2. Doubled.
// Hence the entries to double are
//    idx = 1, 2, ..., N/2 - 1     when N is even
//    idx = 1, 2, ..., (N-1)/2     when N is odd
// i.e. idx = 1, 2, ..., N - onesided_length.
Tensor fft_c2r_backward(const Tensor& grad, IntArrayRef dim, int64_t normalization) {
  const int64_t last_dim = dim.back();
  Tensor gI = at::_fft_r2c(grad, dim, normalization, /*onesided=*/true);

  const int64_t double_length = grad.size(last_dim) - gI.size(last_dim);
  // A non-positive length also covers the case when the signal size is zero.
  if (double_length > 0) {
    gI.narrow(last_dim, 1, double_length).mul_(2);
  }
  return gI;
}

// Backward of an R2C transform (rfft / two-sided real-to-complex fft).
// `last_dim_size` is the length of the transformed signal along dim.back().
Tensor fft_r2c_backward(const Tensor& grad, IntArrayRef dim, int64_t normalization,
                        bool onesided, int64_t last_dim_size) {
  // Two-sided: the backward is the real part of the inverse C2C transform.
  if (!onesided) {
    return at::real(at::_fft_c2c(grad, dim, normalization, /*forward=*/false));
  }

  // Onesided: think of the forward as
  //     1. view as complex numbers (imaginary part zero),
  //     2. C2C fft,
  //     3. discard half of the results.
  // The backward therefore
  //     1. pads the discarded half back with zeros (C2C ifft only accepts
  //        two-sided input, so the padding has to happen here),
  //     2. runs an inverse C2C ifft,
  //     3. discards the imaginary part.
  const int64_t signal_dim = at::maybe_wrap_dim(dim.back(), grad.dim());
  const int64_t half_length = grad.size(dim.back());

  Tensor full_grad = grad;
  if (last_dim_size > half_length) {
    at::DimVector full_shape(grad.sizes().begin(), grad.sizes().end());
    full_shape[signal_dim] = last_dim_size;
    full_grad = at::zeros(full_shape, grad.options());
    full_grad.slice(signal_dim, 0, half_length).copy_(grad);
  }
  return at::real(at::_fft_c2c(full_grad, dim, normalization, /*forward=*/false));
}

// Helper for batchnorm_double_backward.
// Sums `to_sum` over every dim except dim 1. With keepdim the reduced dims are
// kept as size-1 dims; without it only the extent of the original dim 1 remains.
Tensor sum_exclude_dim1(const Tensor& to_sum, bool keepdim=true) {
  Tensor reduced = to_sum.sum(0, keepdim);
  // Index of the original dim 1 once dim 0 has been reduced.
  const int64_t kept_dim = keepdim ? 1 : 0;
  // Reduce from the last dim backwards so lower dim indices stay valid.
  int64_t d = reduced.dim() - 1;
  while (d > kept_dim) {
    reduced = reduced.sum(d, keepdim);
    --d;
  }
  return reduced;
}

// Helper for batchnorm_double_backward.
// Similar to expand_as_dim1 but without the expand: it produces the shape a
// keepdim=True reduction over all dims except dim 1 would have. Size-1 dims are
// appended (via unsqueeze(1)) until src has target.dim() - 1 dims, then a
// leading size-1 dim is added, so a 1-D tensor of size C becomes [1, C, 1, ...].
Tensor unsqueeze_dim1(const Tensor& src, const Tensor& target) {
  const auto target_rank = target.sizes().size();
  Tensor shaped = src;
  for (; shaped.sizes().size() < target_rank - 1; ) {
    shaped = shaped.unsqueeze(1);
  }
  if (shaped.sizes().size() == target_rank - 1) {
    shaped = shaped.unsqueeze(0);
  }
  return shaped;
}

// Helper for batchnorm_double_backward.
// gamma/ggG/ggB are 1-dimensional and represent dim 1, so a plain expand_as
// would not follow the broadcasting rules (which align trailing dims).
// Size-1 dims are appended until src has target.dim() - 1 dims; broadcasting
// then adds the leading dim during expand_as.
Tensor expand_as_dim1(const Tensor& src, const Tensor& target) {
  const auto target_rank = target.sizes().size();
  Tensor shaped = src;
  for (; shaped.sizes().size() < target_rank - 1; ) {
    shaped = shaped.unsqueeze(1);
  }
  return shaped.expand_as(target);
}

// Double backward of batch norm.
//
// ggI / ggG / ggB are incoming gradients, presumably w.r.t. the first
// backward's grad_input / grad_gamma / grad_beta outputs; each may be undefined.
// Returns {gI, gG, ggO}: the gradients w.r.t. `input`, `gamma` and the first
// backward's incoming gradient `gO`. Any entry may be undefined when it
// receives no contribution.
//
// Notation: M is the number of elements reduced per channel (every dim except
// dim 1), mu is the per-channel mean, and sigma2_eps_neg_k_2 stands for
// (sigma^2 + eps)^{-k/2}.
std::tuple<Tensor, Tensor, Tensor> batchnorm_double_backward(
    const Tensor & input,
    const c10::optional<Tensor> & gamma,
    const Tensor & ggI,
    const Tensor & ggG,
    const Tensor & ggB,
    const Tensor & gO,
    const c10::optional<Tensor> & running_mean,
    const c10::optional<Tensor> & running_var,
    bool training,
    double eps,
    const c10::optional<Tensor> & save_mean,
    const c10::optional<Tensor> & save_invstd,
    std::array<bool,3> output_mask) {

  bool affine = isDefined(gamma);
  // TODO: Do we have a ScalarOrTensor type?  Would such a thing exist?
  Tensor gamma_expanded;
  Tensor ggG_expanded, ggB_expanded;
  if (affine) {
    gamma_expanded = expand_as_dim1(*gamma, input);
    if (ggG.defined()) {
      ggG_expanded = expand_as_dim1(ggG, input);
    }
    if (ggB.defined()) {
      ggB_expanded = expand_as_dim1(ggB, input);
    }
  } else {
    // Without an affine transform, gamma acts as the constant 1.
    gamma_expanded = at::ones({}, input.options());
  }

  // define some terms we will reuse
  // M = number of elements per channel (product of every dim except dim 1)
  auto M = input.size(0);
  for (auto s : input.sizes().slice(2)) {
    M *= s;
  }
  // for half inputs, save_mean, save_invstd are float (ideally, we would cast
  // everything else, but not now)
  auto mu = unsqueeze_dim1(training ? toNonOptTensor(save_mean).to(input.scalar_type()) : toNonOptTensor(running_mean), input);
  auto input_sub_mu = input - mu;
  // (sigma^2 + eps)^{-1/2}: the saved invstd in training mode, otherwise
  // recomputed from the running variance.
  auto sigma2_eps_neg_1_2 = unsqueeze_dim1(
      training ? toNonOptTensor(save_invstd).to(input.scalar_type())
               : toNonOptTensor(running_var).add(Scalar(eps)).pow(-0.5),
      input);
  auto sigma2_eps_neg_1 = sigma2_eps_neg_1_2.pow(2);
  auto sigma2_eps_neg_3_2 = sigma2_eps_neg_1_2.pow(3);

  // calculate gI
  auto input_mu_sigma2_neg_3_2 = input_sub_mu * sigma2_eps_neg_3_2;
  // Per-channel reductions of gO. They are computed once and reused below;
  // neither is ever modified in place.
  auto gOinmu_sum = sum_exclude_dim1(gO * input_sub_mu);
  auto gO_sum = sum_exclude_dim1(gO);

  Tensor gI;
  if (ggI.defined() && training) {
    auto ggI_sum = sum_exclude_dim1(ggI);
    auto ggIinmu_sum = sum_exclude_dim1(ggI * input_sub_mu);
    auto all_sub = ((ggI_sum * gO_sum).div_(M)).sub_(sum_exclude_dim1(gO * ggI)).add_(
                    (sigma2_eps_neg_1 * gOinmu_sum * ggIinmu_sum).mul_(3. / M));
    auto gI_0t = (input_mu_sigma2_neg_3_2 * all_sub).div_(M);
    auto gI_1t = (ggIinmu_sum * sigma2_eps_neg_3_2).div_(M) * (gO_sum.div(M) - gO);
    auto gI_2t = (gOinmu_sum * sigma2_eps_neg_3_2).div_(M) * (ggI_sum.div(M) - ggI);
    gI = gamma_expanded * (gI_0t.add_(gI_1t).add_(gI_2t));
  }

  // add contribution of gamma term to gI
  Tensor gI_G_term;
  if (affine && ggG.defined()) {
    if (training) {
      auto t0 = gO * sigma2_eps_neg_1_2;
      auto t1 = (sigma2_eps_neg_1_2 * gO_sum).div_(-M);
      // gOinmu_sum == sum_exclude_dim1(gO * input_sub_mu); reuse it rather
      // than recomputing the same reduction.
      auto t2 = (input_mu_sigma2_neg_3_2 * gOinmu_sum).div_(-M);
      gI_G_term = ggG_expanded * (t0.add_(t1).add_(t2));
      gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term;
    } else {
      gI_G_term = ggG_expanded * sigma2_eps_neg_1_2 * gO;
      gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term;
    }
  }

  // this is the first backward's grad_input
  auto first_back_grad_input = [&](const Tensor& gO, const Tensor& gamma) -> Tensor {
    auto h0 = (gamma * sigma2_eps_neg_1_2).div_(M);
    auto h1 = (M * gO).sub_(sum_exclude_dim1(gO)).sub_(
                input_sub_mu.mul(sigma2_eps_neg_1) * sum_exclude_dim1(gO * input_sub_mu));
    return h0 * h1;
  };

  // calculate gG
  Tensor gG;
  if (affine && ggI.defined()) {
    if (training) {
      // gG is just the first backwards with the gamma term removed (then shaped properly)
      gG = ggI * first_back_grad_input(gO, at::ones({}, sigma2_eps_neg_1_2.options()));
      gG = sum_exclude_dim1(gG, false);
    } else {
      gG = sum_exclude_dim1(ggI * gO * sigma2_eps_neg_1_2, false);
    }
  }

  // calculate ggO
  Tensor ggO;
  // contribution of input term
  if (ggI.defined()) {
    if (training) {
      ggO = first_back_grad_input(ggI, gamma_expanded);
    } else {
      ggO = ggI * sigma2_eps_neg_1_2 * gamma_expanded;
    }
  }
  if (ggG.defined()) {
    auto ggO_G_term = ggG_expanded * input_sub_mu * sigma2_eps_neg_1_2;
    ggO = ggO.defined() ? ggO.add_(ggO_G_term) : ggO_G_term;
  }
  if (ggB.defined()) {
    auto ggO_B_term = ggB_expanded;
    ggO = ggO.defined() ? ggO.add_(ggO_B_term) : ggO_B_term;
  }

  if (output_mask[1] && !gG.defined()) {
    AT_ASSERTM(affine, "gamma should always be defined when it requires grad");
  }

  return std::tuple<Tensor, Tensor, Tensor>{gI, gG, ggO};

}

// Double backward of layer norm.
//
// The input is viewed as an (M, N) matrix. M is the product of the leading,
// non-normalized dims; each row is normalized over N = prod(normalized_shape)
// columns. ggI / ggG / ggB are incoming gradients, each possibly undefined
// (presumably w.r.t. the first backward's grad_input / grad_gamma / grad_beta).
// Returns {gI, gG, ggO}: the gradients w.r.t. `input_t`, `gamma` and the first
// backward's incoming gradient `gO_t`; entries may be undefined.
std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
    const Tensor& input_t,
    const c10::optional<Tensor>& gamma,
    const Tensor& ggI,
    const Tensor& ggG,
    const Tensor& ggB,
    const Tensor& gO_t,
    const Tensor& save_mean_t,
    const Tensor& save_invstd_t,
    IntArrayRef normalized_shape,
    std::array<bool, 3> output_mask) {

  // Ranks are kept in int64_t to avoid size_t -> int narrowing.
  const int64_t normalized_ndim = static_cast<int64_t>(normalized_shape.size());
  const auto input_shape = input_t.sizes();
  const int64_t input_ndim = input_t.dim();
  const int64_t axis = input_ndim - normalized_ndim;
  const int64_t M =
      c10::multiply_integers(input_shape.cbegin(), input_shape.cbegin() + axis);
  const int64_t N =
      c10::multiply_integers(input_shape.cbegin() + axis, input_shape.cend());

  auto input = input_t.reshape({M, N});
  auto gO = gO_t.reshape({M, N});
  auto save_mean = save_mean_t.reshape({M, 1});
  auto save_invstd = save_invstd_t.reshape({M, 1});

  bool affine = isDefined(gamma);
  Tensor gamma_expanded;
  Tensor ggG_expanded, ggB_expanded;
  if (affine) {
    gamma_expanded = gamma->reshape({1, N});
    if (ggG.defined()) {
      ggG_expanded = ggG.reshape({1, N});
    }
    if (ggB.defined()) {
      ggB_expanded = ggB.reshape({1, N});
    }
  } else {
    // Without an affine transform, gamma acts as the constant 1.
    gamma_expanded = at::ones({1}, input.options());
  }

  Tensor ggI_expanded;
  if (ggI.defined()) {
    ggI_expanded = ggI.reshape({M, N});
  }

  // for half inputs, save_mean, save_invstd are float
  // (ideally, we would cast everything else, but not now)
  auto mu = save_mean.to(input.scalar_type());
  auto input_sub_mu = input - mu;
  // sigma2_eps_neg_k_2 stands for the saved invstd raised to the power k.
  auto sigma2_eps_neg_1_2 = save_invstd.to(input.scalar_type());
  auto sigma2_eps_neg_1 = sigma2_eps_neg_1_2.pow(2);
  auto sigma2_eps_neg_3_2 = sigma2_eps_neg_1_2.pow(3);

  Tensor gI;
  // calculate gI
  auto input_mu_sigma2_neg_3_2 = input_sub_mu * sigma2_eps_neg_3_2;

  if (ggI.defined()) {

    auto gxhat = gO * gamma_expanded;
    auto gxhat_mu_sum = (gxhat * input_sub_mu).sum(1, true);
    auto gxhat_sum = gxhat.sum(1, true);

    auto ggI_sum = ggI_expanded.sum(1, true);
    auto ggI_mu_sum = (ggI_expanded * input_sub_mu).sum(1, true);

    auto all_sub = ((ggI_sum * gxhat_sum).div_(N)).sub_((ggI_expanded * gxhat).sum(1, true)).add_(
                    (sigma2_eps_neg_1 * gxhat_mu_sum * ggI_mu_sum).mul_(3. / N));
    auto gI_0t = (input_mu_sigma2_neg_3_2 * all_sub).div_(N);
    auto gI_1t = (ggI_mu_sum * sigma2_eps_neg_3_2).div_(N) * (gxhat_sum.div(N) - gxhat);
    auto gI_2t = (gxhat_mu_sum * sigma2_eps_neg_3_2).div_(N) * (ggI_sum.div(N) - ggI_expanded);

    gI = (gI_0t.add_(gI_1t).add_(gI_2t));
  }

  // add contribution of gamma term to gI
  if (affine && ggG.defined()) {
    // gO * ggG is shared by all three terms; compute it once (never modified
    // in place below).
    const auto gO_ggG = gO * ggG_expanded;
    auto t0 = gO_ggG * sigma2_eps_neg_1_2;
    auto t1 = (sigma2_eps_neg_1_2 * gO_ggG.sum(1, true)).div_(-N);
    auto t2 = (input_mu_sigma2_neg_3_2 * (gO_ggG * input_sub_mu).sum(1, true)).div_(-N);
    auto gI_G_term = t0.add_(t1).add_(t2);
    gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term;
  }

  if (gI.defined()) {
    gI = gI.reshape_as(input_t);
  }

  // this is the grad_input for the first backward function
  auto first_bwd_fn_grad_input = [&](const Tensor& gO_local, const Tensor& gamma_local) -> Tensor {
    auto h0 = (gamma_local * sigma2_eps_neg_1_2).div_(N);
    auto h1 = (N * gO_local).sub_(gO_local.sum(1,true)).sub_(
                input_sub_mu.mul(sigma2_eps_neg_1) * (gO_local * input_sub_mu).sum(1,true));
    return h0 * h1;
  };

  // calculate gG
  Tensor gG;
  if (affine && ggI.defined()) {
    gG = first_bwd_fn_grad_input(ggI_expanded, at::ones({}, sigma2_eps_neg_1_2.options()));
    gG = (gO * gG).sum(0);
    gG = gG.reshape_as(*gamma);
  }

  // calculate ggO
  Tensor ggO;
  // contribution of input term
  if (ggI.defined()) {
    ggO = first_bwd_fn_grad_input(ggI_expanded, gamma_expanded);
  }
  if (ggG.defined()) {
    auto ggO_G_term = ggG_expanded * input_sub_mu * sigma2_eps_neg_1_2;
    ggO = ggO.defined() ? ggO.add_(ggO_G_term) : ggO_G_term;
  }
  if (ggB.defined()) {
    auto ggO_B_term = ggB_expanded;
    ggO = ggO.defined() ? ggO.add_(ggO_B_term) : ggO_B_term;
  }
  if (ggO.defined()) {
    // The ggB-only contribution is (1, N); expand before restoring the input shape.
    ggO = ggO.expand({M, N}).reshape_as(input_t);
  }

  if (output_mask[1] && !gG.defined()) {
    AT_ASSERTM(affine, "gamma should always be defined when it requires grad");
  }

  return std::tuple<Tensor, Tensor, Tensor>{gI, gG, ggO};
}

// Backward of native_group_norm, written with differentiable ATen ops.
//
// dY, dmean and drstd are the incoming gradients for the forward outputs; each
// may be undefined. X is viewed as (N, G, D, HxW) with D = C / G channels per
// group, and mean / rstd are viewed as (N, G, 1, 1).
// Returns {dX, dgamma, dbeta}. Each entry is only computed when the matching
// grad_input_mask flag is set; dgamma and dbeta additionally require dY.
std::tuple<Tensor, Tensor, Tensor>
infinitely_differentiable_native_group_norm_backward(
    const Tensor& dY,
    const Tensor& dmean,
    const Tensor& drstd,
    const Tensor& X,
    const Tensor& mean,
    const Tensor& rstd,
    const c10::optional<Tensor>& gamma,
    int64_t N,
    int64_t C,
    int64_t HxW,
    int64_t group,
    double eps,
    std::array<bool, 3> grad_input_mask) {
  const int64_t G = group;
  const int64_t D = C / G;
  // s = 1 / (number of elements in one group of one sample).
  const double s = 1.0 / static_cast<double>(D * HxW);
  Tensor dX;
  Tensor dgamma;
  Tensor dbeta;
  const Tensor X_tensor = X.reshape({N, G, D, HxW});
  const Tensor mean_tensor = mean.reshape({N, G, 1, 1});
  const Tensor rstd_tensor = rstd.reshape({N, G, 1, 1});
  Tensor dY_tensor;
  Tensor ds;
  Tensor db;
  if (dY.defined()) {
    dY_tensor = dY.reshape({N, G, D, HxW});
    // Per-channel reductions over HxW, shape (N, G, D, 1):
    // ds = sum(dY * X), db = sum(dY).
    ds = (dY_tensor * X_tensor).sum(3).unsqueeze_(-1);
    db = dY_tensor.sum(3).unsqueeze_(-1);
  }
  if (grad_input_mask[0]) {
    Tensor gamma_tensor;
    if (isDefined(gamma)) {
      gamma_tensor = gamma->reshape({1, G, D, 1});
    }
    // Recover the variance from rstd: var = 1 / rstd^2 - eps, clamped at 0
    // (presumably to absorb rounding that would make it negative).
    const Tensor var =
        ((rstd_tensor * rstd_tensor).reciprocal_() - eps).clamp_min(0);
    const Tensor rstd_cube = rstd_tensor * rstd_tensor * rstd_tensor;
    Tensor dvar;
    if (drstd.defined()) {
      // Chain rule through rstd = (var + eps)^{-1/2}: d rstd / d var = -rstd^3 / 2.
      dvar = -0.5 * rstd_cube * drstd.view({N, G, 1, 1});
    }
    if (dY.defined()) {
      // Contribution of dY: dX = a * dY + b * X + c, with per-group
      // coefficients b, c (gamma folded in when present).
      const Tensor a =
          isDefined(gamma) ? rstd_tensor * gamma_tensor : rstd_tensor;
      Tensor b = (isDefined(gamma) ? (ds * gamma_tensor).sum(2) : ds.sum(2))
                     .unsqueeze_(-2);
      Tensor c = (isDefined(gamma) ? (db * gamma_tensor).sum(2) : db.sum(2))
                     .unsqueeze_(-2);
      b = (c * mean_tensor - b) * rstd_cube * s;
      c = -b * mean_tensor - c * rstd_tensor * s;
      dX = a * dY_tensor + b * X_tensor + c;
      // NOTE(review): the dmean / drstd contribution is only added when BOTH
      // are defined; if just one is defined its contribution is dropped —
      // confirm this is intended.
      if (dmean.defined() && drstd.defined()) {
        // Contribution of the (var, mean) statistics over dims {2, 3} of each
        // group (the trailing 0 / true / false presumably mean correction=0,
        // keepdim=true, is_std=false — confirm against var_std_mean_backward).
        dX += var_std_mean_backward(
            {dvar, dmean.view({N, G, 1, 1})},
            X_tensor,
            var,
            mean_tensor,
            IntArrayRef{2, 3},
            0,
            true,
            false);
      }
      dX = dX.reshape_as(X);
    } else if (dmean.defined() && drstd.defined()) {
      // No dY: only the statistics contribute to dX.
      dX = var_std_mean_backward(
               {dvar, dmean.view({N, G, 1, 1})},
               X_tensor,
               var,
               mean_tensor,
               IntArrayRef{2, 3},
               0,
               true,
               false)
               .reshape_as(X);
    }
  }
  if (grad_input_mask[1] && dY.defined()) {
    // dgamma = sum over the batch of (ds - db * mean) * rstd, shaped like gamma.
    dgamma = ((ds - db * mean_tensor) * rstd_tensor).sum(0).reshape_as(toNonOptTensor(gamma));
  }
  if (grad_input_mask[2] && dY.defined()) {
    // dbeta = sum over the batch of db, shaped like gamma.
    dbeta = db.sum(0).reshape_as(toNonOptTensor(gamma));
  }

  return std::make_tuple(dX, dgamma, dbeta);
}

std::tuple<Tensor, Tensor, Tensor> _trilinear_backward(const Tensor& grad_out, const Tensor& i1, const Tensor& i2, const Tensor& i3,
                                                       IntArrayRef expand1, IntArrayRef expand2, IntArrayRef expand3,
                                                       IntArrayRef sumdim, std::array<bool, 3> grad_mask) {
  Tensor grad_i1, grad_i2, grad_i3;
  if (grad_out.defined()) {
    if (grad_mask[0])
      grad_i1 = at::_trilinear(grad_out, i2, i3, sumdim, expand2, expand3, expand1);
    if (grad_mask[1])
      grad_i2 = at::_trilinear(i1, grad_out, i3, expand1, sumdim, expand3, expand2);
    if (grad_mask[2])
      grad_i3 = at::_trilinear(i1, i2, grad_out, expand1, expand2, sumdim, expand3);
  }
  return std::tuple<Tensor, Tensor, Tensor>(grad_i1, grad_i2, grad_i3);
}

// Backward of log1p: grad / conj(1 + self).
// Sparse inputs are rejected, because the local gradient is non-zero at zero
// entries and would densify the gradient.
Tensor log1p_backward(const Tensor& grad, const Tensor& self) {
  if (self.is_sparse()) {
    AT_ERROR(
      "log1p of a sparse tensor is made to be non-differentiable since ",
      "local gradient of zero is 1 / (0 + 1) = 1 and it makes the tensor dense. ",
      "Use a different mathematical operation which preserves sparsity of gradients, ",
      "or report a bug if you think this is an error.");
  }
  const Tensor denominator = (self + 1).conj();
  return grad / denominator;
}

// Backward of sinc(x) = sin(pi x) / (pi x):
//   d/dx sinc(x) = (pi x cos(pi x) - sin(pi x)) / (pi x^2),
// with the gradient defined as 0 wherever pi x^2 == 0.
Tensor sinc_backward(const Tensor& grad, const Tensor& self) {
  const auto pi_x = self * M_PI;
  const auto pi_x_sq = self * self * M_PI;
  const auto dsinc = (pi_x * pi_x.cos() - pi_x.sin()) / pi_x_sq;
  const auto scaled = grad * dsinc.conj();
  return at::where(pi_x_sq == 0.0, at::zeros({}, grad.options()), scaled);
}

// Backward of a sparse constructor w.r.t. its values: gathers the entries of
// the (coalesced) sparse gradient at the constructor's `indices` via
// _sparse_mask_helper.
Tensor sparse_constructor_values_backward(const Tensor& sparse_grad_out, const Tensor& indices) {
  const Tensor coalesced_grad = sparse_grad_out.coalesce();
  const Tensor contiguous_indices = indices.contiguous();
  return _sparse_mask_helper(coalesced_grad, contiguous_indices);
}

// The backward of pad(input, pads) is pad(grad_output, [-p for p in pads]):
// negative padding crops the padded region away again.
Tensor constant_pad_nd_backward(const Tensor& grad, IntArrayRef pad) {
  std::vector<int64_t> negated_pad;
  negated_pad.reserve(pad.size());
  for (const int64_t p : pad) {
    negated_pad.push_back(-p);
  }
  return at::constant_pad_nd(grad, negated_pad, 0);
}

// Double backward of dense embedding: gathers rows of `grad` at `indices`,
// zeroes rows whose index equals padding_idx (when padding_idx >= 0), and
// returns them shaped as indices.sizes() + [-1].
Tensor embedding_dense_double_backward(const Tensor & grad, const Tensor & indices, int64_t padding_idx) {
  // The first backward already scaled by frequency, so no rescaling here.
  auto gg_weight = grad.index_select(0, indices.reshape(-1));

  if (padding_idx >= 0) {
    const auto padding_rows = (indices == padding_idx).reshape({-1, 1});
    gg_weight.masked_fill_(padding_rows, 0);
  }

  // Restore the shape of `indices`, with the embedding dim last.
  auto out_shape = indices.sizes().vec();
  out_shape.push_back(-1);
  return gg_weight.view(out_shape);
}

// Backward of advanced indexing: accumulate `grad` into the zero tensor at the
// indexed positions (duplicates add up).
Tensor index_backward(Tensor zeros_like_self, const torch::List<c10::optional<Tensor>>& indices, const Tensor& grad) {
  constexpr bool accumulate = true;
  constexpr bool unsafe = true;
  return at::_index_put_impl_(zeros_like_self, indices, grad, accumulate, unsafe);
}

// Backward of the cuDNN CTC loss: scales the precomputed raw gradient by
// grad_out (broadcast over dims 0 and 2). With zero_infinity, the gradient is
// zeroed wherever the corresponding loss is exactly 0.
Tensor _cudnn_ctc_loss_backward(const Tensor& grad_out, const Tensor& loss, const Tensor& raw_grad, bool zero_infinity) {
  Tensor scaled_grad = raw_grad * grad_out.unsqueeze(0).unsqueeze(2);
  if (!zero_infinity) {
    return scaled_grad;
  }
  const Tensor zero_loss = loss.unsqueeze(0).unsqueeze(2) == 0;
  return at::where(zero_loss, at::zeros({}, raw_grad.options()), scaled_grad);
}

bool any_variable_defined(const variable_list& variables) {
  for (const auto& variable : variables) {
    if (variable.defined()) {
      return true;
    }
  }
  return false;
}

// Derivations for the householder_product.backward method.
//
// Given a sequence of vectors v_1, ..., v_n and a sequence of scalars tau_1, ..., tau_k,
// torch.linalg.householder_product computes the first n columns of the following product:
// Q = (I - tau_1 v_1 v_1^H) ... (I - tau_k v_k v_k^H).
// Let
//     H_i(sigma) := I - sigma v_i v_i^H, so Q = (H_1(tau_1) ... H_k(tau_k))[:, :n];
//     H_i_minus = H_1(tau_1) ... H_{i - 1}(tau_{i - 1}), with H_1_minus := I;
//     H_i_plus = H_{i + 1}(tau_{i + 1}) ... H_k(tau_k) with H_k_plus := I;
//
// Forward AD:
// dQ = sum_{i = 1}^k H_i_minus (-dtau_i v_i v_i^H - tau_i dv_i v_i^H - tau_i v_i dv_i^H) H_i_plus.
//
// Backward AD:
// Tr(Q_grad^H dQ) = sum_{i = 1}^k Tr(H_i_plus Q_grad^H H_i_minus (-dtau_i v_i v_i^H - tau_i dv_i v_i^H - tau_i v_i dv_i^H)).
// Let K_i := H_i_plus Q_grad^H H_i_minus, then the gradients are
// v_i_grad = (-tau_i v_i^H K_i)^H - tau_i K_i v_i,
// tau_i_grad = Tr(-v_i^H K_i v_i).conj().
// NOTE: the algorithm ignores that only n columns of Q are observed, so there is no need to
// recompute Q to full completion.
//
// Note that K_{i + 1} = H_{i + 1}^{-1} K_i H_i, so we can compute v_i_grad, tau_i_grad one by one
// by just efficiently updating K_i if that is possible. Multiplying with H_i from the right could be
// done with matrix-vector products, but what about the inverse H_{i + 1}^{-1} and does it even exist?
// Luckily, under some assumptions, H_{i + 1}^{-1} exists and admits a representation as H_{i + 1}(sigma_{i + 1})
// for some sigma_{i + 1}, so the left update can also be done with matrix-vector rather than matrix-matrix products.
//
// Let H(tau) := I - tau v v^H.
// H(tau) has eigenvalues 1 with multiplicity (m - 1) with eigenvectors orthogonal to v,
// and an eigenvalue (1 - tau ||v||^2) with the corresponding eigenvector v / ||v||.
// If (1 - tau ||v||^2) != 0, H(tau) is invertible, and with sigma defined as
// sigma := tau / (||v||^2 tau - 1) we get that
// H(tau) H(sigma) = H(sigma) H(tau) = I, so H(sigma) is the inverse of H(tau).
//
// WARNING: the algorithm below assumes that H_i(tau_i) are all invertible, so
// it expects that (1 - tau_i ||v_i||^2) != 0 for all i.
// We would like to point out that if there is H_i(tau_i) which is not invertible,
// the householder_product is still differentiable! We will not be able to compute K_i
// efficiently in such cases, however, as evaluating each K_i would then require calls
// to ORGQR to compute H_i_plus.

// This function computes either the product between
// (I - tau u v^H) and K (in-place or not) with `condition_with_I = true`, or between
// (-tau u v^H) and K (out-of-place only) with `condition_with_I = false`.
// Parameter `left` controls whether the matrix K is multiplied from the left or
// from the right.
// Additionally, when the computation is done in-place, we exploit that the first
// `k` coordinates of `u_full/v_full` are zeros.
//
// Shapes: u_full and v_full are column vectors of shape (..., m, 1); t is a scalar
// per batch of shape (..., 1), unsqueezed to (..., 1, 1) to broadcast against matrices.
// With `left = true` K must have m rows; with `left = false` it must have m columns.
// In the in-place path `condition_with_I` is not read (I - t u v^H is always applied)
// and the returned tensor is K itself, modified in place.
Tensor apply_simple_transformation(
    int64_t m, int64_t k,
    const Tensor& u_full, const Tensor& v_full,
    const Tensor& t, Tensor& K,
    bool modify_K_in_place = true,
    bool condition_with_I = true,
    bool left = true) {
  // TODO: matrix-vector products in the code below are dispatched to matrix-matrix products.
  // We either need to extend matmul to support batched matrix-vector products, or
  // implement a batched variant of mv.
  // We could enable mv for inputs which are not batched, but it is not done to eliminate
  // the code duplication.

  if (left) {
    // returns (I - t u v^H) K or -t u v^H K
    if (modify_K_in_place) {
      // v^H K only needs rows k: of K because v_full[:k] == 0, and only rows k: of K
      // change because u_full[:k] == 0.
      auto u_tail = u_full.narrow(-2, k, m - k);
      auto vHK = v_full.narrow(-2, k, m - k).mH().matmul(K.narrow(-2, k, m - k));
      K.narrow(-2, k, m - k).sub_((t.unsqueeze(-1) * u_tail) * vHK);
      return K;
    }
    auto transformation = (t.unsqueeze(-1) * u_full) * v_full.mH().matmul(K);
    return condition_with_I ? K - transformation : -transformation;
  }

  // returns K (I - t u v^H) or -K t u v^H
  if (modify_K_in_place) {
    // K u only needs columns k: of K because u_full[:k] == 0, and only columns k: of K
    // change because v_full[:k] == 0. This matches the out-of-place formula below,
    // K - t (K u) v^H; the roles of u and v must not be swapped when u != v.
    auto v_tail = v_full.narrow(-2, k, m - k);
    auto Ku = K.narrow(-1, k, m - k).matmul(t.unsqueeze(-1) * u_full.narrow(-2, k, m - k));
    K.narrow(-1, k, m - k).sub_(Ku * v_tail.mH());
    return K;
  }
  auto transformation = K.matmul(t.unsqueeze(-1) * u_full) * v_full.mH();
  return condition_with_I ? K - transformation : -transformation;
}

// Backward of torch.linalg.householder_product.
// grad   - gradient w.r.t. the output Q (may be undefined);
// result - the forward output Q;
// input_ - the Householder vectors, one per column, below the diagonal (the
//          diagonal is treated as 1 and the upper part is ignored);
// tau    - the Householder scalars; k = tau.size(-1) reflectors are used.
// Returns (input_grad, tau_grad), or two undefined tensors when `grad` is undefined
// or when `input_` or `tau` has no elements. The derivation, including the assumption
// that every H_i(tau_i) is invertible, is in the comment block above
// `apply_simple_transformation`.
std::tuple<Tensor, Tensor> householder_product_backward(const Tensor& grad, const Tensor& result, const Tensor& input_, const Tensor& tau) {
  if (!grad.defined() || !input_.numel() || !tau.numel()) {
    return std::tuple<Tensor, Tensor>(Tensor(), Tensor());
  }

  auto input_grad = at::zeros_like(input_);
  auto tau_grad = at::zeros_like(tau);

  // m: number of rows of the input; k: number of reflectors.
  auto m = input_.size(-2);
  auto k = tau.size(-1);

  // forward operates only over the lower triangular part with the assumption
  // that the diagonal of input is filled with 1s.
  auto input = input_.tril(-1);
  input.diagonal(0, -2, -1).fill_(1.0);

  // compute sigma such that
  // H(sigma_i) == H(tau_i)^{-1}.
  // If the input to householder_product comes from GEQRF,
  // we will never encounter ||v_i||^2 tau_i == 1, so H(tau_i) will always be invertible.
  // This follows from the documentation https://www.netlib.org/lapack/lug/node128.html,
  // and tau always satisfying the condition |tau|^2 ||v||^2 == 2 * Re(tau).
  auto input_first_k_cols = input.narrow(-1, 0, k);
  auto input_first_k_cols_norm_squared = (
    input_first_k_cols * input_first_k_cols.conj()
  ).sum(-2);
  auto sigma = tau / (tau * input_first_k_cols_norm_squared - 1.0);

  // Start from Q Q_grad^H; the left multiplication by the inverse of the first
  // reflector below turns it into K for the first reflector.
  auto K = result.matmul(grad.mH());

  // The algorithm updates K by multiplying it from the left/right with Householder reflectors.
  // If only single backward is run, we modify K in-place and exploit triangularity of the input.
  // With higher order derivatives we cannot rewrite the storage of K, hence we use much less efficient
  // out-of-place methods.
  //
  // if only first-order derivative is expected, we can modify K in-place for better performance
  bool modify_K_in_place = !at::GradMode::is_enabled();

  // This lambda exploits that at the k-th iteration the vector v_k is non-zero only in v_k[k:].
  // It returns (v_grad, tau_grad) for one reflector, following
  //   v_grad = (-tau v^H K)^H - tau K v,   tau_grad = (-v^H K v).conj().
  auto update_grad = [&m](int64_t k, const Tensor& v_full, const Tensor& t, const Tensor& K) -> std::tuple<Tensor, Tensor> {
    // v_full is a vector of dimension (..., m, 1), t is a scalar of dimension (..., 1)
    auto v = v_full.narrow(-2, k, m - k);
    auto vHK = v.mH().matmul(K.narrow(-2, k, m - k));
    auto Kv = K.narrow(-1, k, m - k).matmul(v);
    auto t_unsqueezed = t.unsqueeze(-1);
    auto v_grad = (-t_unsqueezed * vHK).conj().squeeze(-2) - (t_unsqueezed * Kv).squeeze(-1);
    auto tau_grad = -(vHK.narrow(-1, k, m - k).matmul(v)).conj();
    return std::make_tuple(v_grad, tau_grad.squeeze(-1));
  };

  // Applies H(t) = I - t v v^H to K from the left or the right (in place when
  // modify_K_in_place is set), using that the first k entries of v_full are zero.
  auto apply_householder_reflector = [m, modify_K_in_place](
    int64_t k, const Tensor& v_full,
    const Tensor& t, Tensor& K,
    bool left = true) -> Tensor {
    return apply_simple_transformation(
      m, k, v_full, v_full, t, K, modify_K_in_place, /*condition_with_I=*/true, left
    );
  };

  // K <- H_0^{-1} @ K
  K = apply_householder_reflector(
    0, input.narrow(-1, 0, 1), sigma.narrow(-1, 0, 1),
    K, /*left=*/true
  );
  for (const auto i : c10::irange(k)) {
    // NOTE: narrow will unsqueeze(-1)
    auto v_i = input.narrow(-1, i, 1);
    auto t_i = tau.narrow(-1, i, 1);

    Tensor v_i_grad, tau_i_grad;
    std::tie(v_i_grad, tau_i_grad) = update_grad(i, v_i, t_i, K);
    input_grad.select(-1, i).copy_(v_i_grad.squeeze(-1));
    tau_grad.select(-1, i).copy_(tau_i_grad.squeeze(-1));

    // K <- H_{i + 1}^{-1} @ K @ H_i
    if (i < k - 1) {
      auto v_i_next = input.narrow(-1, i + 1, 1);
      auto s_i_next = sigma.narrow(-1, i + 1, 1);
      K = apply_householder_reflector(
        i + 1, v_i_next, s_i_next,
        K, /*left=*/true
      );
      K = apply_householder_reflector(
        i, v_i, t_i,
        K, /*left=*/false
      );
    }
  }

  // forward operates only over the lower-triangular part of the input
  // excluding the main diagonal, hence the gradient is also lower-triangular.
  input_grad.tril_(-1);

  return std::make_tuple(input_grad, tau_grad);
}

// Forward-mode derivative (JVP) of torch.linalg.householder_product.
// We refer to the derivations described above the method `apply_simple_transformation`:
//   dQ = sum_i H_i_minus (-dtau_i v_i v_i^H - tau_i dv_i v_i^H - tau_i v_i dv_i^H) H_i_plus.
// dV_, dtau - tangents of V_ and tau;
// prod      - the forward output Q;
// V_, tau   - the forward inputs (Householder vectors below the diagonal, and scalars).
Tensor householder_product_jvp(
    const Tensor& dV_,
    const Tensor& dtau,
    const Tensor& prod,
    const Tensor& V_,
    const Tensor& tau
) {
  auto m = V_.size(-2);
  auto k = tau.size(-1);

  // forward operates only over the lower triangular part with the assumption
  // that the diagonal of input is filled with 1s.
  // The diagonal is fixed, so the tangent keeps only the strictly-lower part.
  auto V = V_.tril(-1);
  V.diagonal(0, -2, -1).fill_(1.0);
  auto dV = dV_.tril(-1);

  // compute sigma such that
  // H(sigma_i) == H(tau_i)^{-1}.
  // If the input to householder_product comes from GEQRF,
  // we will never encounter ||v_i||^2 tau_i == 1, so H(tau_i) will always be invertible.
  // This follows from the documentation https://www.netlib.org/lapack/lug/node128.html,
  // and tau always satisfying the condition |tau|^2 ||v||^2 == 2 * Re(tau).
  auto V_first_k_cols = V.narrow(-1, 0, k);
  auto V_first_k_cols_norm_squared = (
    V_first_k_cols * V_first_k_cols.conj()
  ).sum(-2);
  auto sigma = tau / (tau * V_first_k_cols_norm_squared - 1.0);

  // Applies H(t) = I - t v v^H to K from the left or the right, always out of place.
  auto apply_householder_reflector = [m](
    const Tensor& v_full,
    const Tensor& t, Tensor& K,
    bool left = true) -> Tensor {
    return apply_simple_transformation(
      // setting `modify_K_in_place = true` causes CUDA memory leaks in OpInfo tests of forward AD
      // for that reason we ignore `k` by passing zero
      m, /*k=*/0, v_full, v_full, t, K, /*modify_K_in_place=*/false, /*condition_with_I=*/true, left
    );
  };

  // computes (-t u v^H) K
  auto apply_simple_product = [m](
    const Tensor& u_full, const Tensor& v_full,
    const Tensor& t, Tensor& K
  ) -> Tensor {
    return apply_simple_transformation(
      // since `modify_K_in_place = false`, we can ignore `k` and pass an arbitrary value
      m, /*k=*/0, u_full, v_full, t, K, /*modify_K_in_place=*/false, /*condition_with_I=*/false, /*left=*/true
    );
  };


  // H_plus starts as Q and H_minus as a batch of identities: batch_vector_shape is
  // V.shape[:-1] = (..., m), so diag_embed yields (..., m, m) identity matrices.
  // In iteration i, the i-th reflector is first stripped from the left of H_plus (via its
  // inverse H(sigma_i)), and after the dprod update it is appended to the right of H_minus.
  // During the dprod update they therefore equal H_i_minus and H_i_plus from the derivation.
  auto H_plus = prod.detach().clone();
  IntArrayRef batch_vector_shape(V.sizes().data(), V.dim() - 1);
  auto H_minus = at::diag_embed(at::ones({1}, V.options()).expand(batch_vector_shape));

  auto dprod = at::zeros_like(prod);
  for (const auto i : c10::irange(k)) {
    auto v_i = V.narrow(-1, i, 1);
    auto dv_i = dV.narrow(-1, i, 1);
    auto tau_i = tau.narrow(-1, i, 1);
    auto dtau_i = dtau.narrow(-1, i, 1);
    auto sigma_i = sigma.narrow(-1, i, 1);

    H_plus = apply_householder_reflector(v_i, sigma_i, H_plus, /*left=*/true);

    // H_i_minus (-dtau_i v_i v_i^H - tau_i dv_i v_i^H - tau_i v_i dv_i^H) H_i_plus
    dprod.add_(H_minus.matmul(
      apply_simple_product(v_i, v_i, dtau_i, H_plus)
      + apply_simple_product(dv_i, v_i, tau_i, H_plus)
      + apply_simple_product(v_i, dv_i, tau_i, H_plus)
    ));

    H_minus = apply_householder_reflector(v_i, tau_i, H_minus, /*left=*/false);
  }

  return dprod;
}

std::tuple<Tensor, Tensor> polar_backward(
    const Tensor& grad,
    const Tensor& result) {
  Tensor grad_abs, grad_angle;
  if (grad.defined()) {
    auto grad_conj = grad.conj();
    grad_abs = at::real(grad_conj * at::sgn(result));
    auto result_mul_1_j = result * Scalar(c10::complex<double>{0.0, 1.0});
    grad_angle = at::real(grad_conj * result_mul_1_j);
  }
  return std::make_tuple(grad_abs, grad_angle);
}

// Backward of the modified Bessel function I1: d/dx I1(x) = I0(x) - I1(x) / x.
Tensor i1_backward(
    const Tensor& grad,
    const Tensor& self,
    const Tensor& result) {
  return AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "i1_backward", [&]() {
    // The limit of the derivative at x = 0 is 0.5, but the formula evaluates to
    // NaN there, so inputs with |x| <= eps are special-cased.
    const auto tiny = std::numeric_limits<scalar_t>::epsilon();
    const auto regular = self.abs() > tiny;

    // `where` still computes gradients for the branch it discards
    // (https://github.com/pytorch/pytorch/issues/52248), so tiny inputs are replaced
    // by eps before dividing. Update if and when this is fixed.
    const auto x = at::where(regular, self, at::full({}, tiny, self.options()));
    const auto derivative = x.i0() - result * x.reciprocal();
    const auto at_zero = at::full({}, 0.5, self.options());
    return grad * at::where(regular, derivative, at_zero);
  });
}

// Backward of the exponentially scaled Bessel function i1e(x) = exp(-|x|) I1(x):
//   d/dx i1e(x) = i0e(x) - i1e(x) * (sgn(x) + 1 / x).
Tensor i1e_backward(
    const Tensor& grad,
    const Tensor& self,
    const Tensor& result) {
  return AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "i1e_backward", [&]() {
    // The limit of the derivative at x = 0 is 0.5, but the formula evaluates to
    // NaN there, so inputs with |x| <= eps are special-cased.
    const auto tiny = std::numeric_limits<scalar_t>::epsilon();
    const auto regular = self.abs() > tiny;

    // `where` still computes gradients for the branch it discards
    // (https://github.com/pytorch/pytorch/issues/52248), so tiny inputs are replaced
    // by eps before dividing. Update if and when this is fixed.
    const auto x = at::where(regular, self, at::full({}, tiny, self.options()));
    const auto derivative =
        at::special_i0e(x) - result * (x.sgn() + x.reciprocal());
    const auto at_zero = at::full({}, 0.5, self.options());
    return grad * at::where(regular, derivative, at_zero);
  });
}

// lu_solve is a map (LU, P, B) -> (PLU)^{-1} B,
// where LU = L + U - I and P is a permutation matrix, and is fixed.
//
// Let 1 = ones_like(LU),
// 1_U = 1.triu(),
// 1_L = 1.tril(-1)
// * := the Hadamard (element-wise) product
//
// Forward AD:
//
// Let X := U^{-1} L^{-1} P^T B be the output of the function.
// Also, the LU input of the function could be represented as
// LU = (L - I) + U.
//
// Differentiating LU = L + U - I produces:
// dLU = dL + dU.
// Noting that dL and dU are lower- and upper-triangular, respectively,
// and that the diagonal of L is never explicitly exposed, so
// diag(dL) = 0, it follows
// dL = dLU * 1_L,
// dU = dLU * 1_U.
//
// Differentiating X = U^{-1} L^{-1} P^T B produces:
// dX = dU^{-1} L^{-1} P^T B + U^{-1} dL^{-1} P^T B + U^{-1} L^{-1} P^T dB
// Note that for any invertible matrix A we have A A^{-1} = I, hence
// dA A^{-1} + A dA^{-1} = 0 => dA^{-1} = -A^{-1} dA A^{-1}.
// Inserting it back into the definition of dX gives:
// dX = -U^{-1} dU U^{-1} L^{-1} P^T B - U^{-1} L^{-1} dL L^{-1} P^T B + U^{-1} L^{-1} P^T dB
// dX = -U^{-1} dU X - U^{-1} L^{-1} dL U X + U^{-1} L^{-1} P^T dB
//
// Backward AD:
//
// Using the definition of dL, dU from above:
// Tr(L_grad^H dL) + Tr(U_grad^H dU) = Tr(L_grad^H (dLU * 1_L)) + Tr(U_grad^H (dLU * 1_U))
//                                   = [using Tr(A (B * C)) = Tr((A * B^T) C)]
//                                   = Tr((L_grad^H * 1_L^T) dLU) + Tr((U_grad^H * 1_U^T) dLU),
// hence
// LU_grad = L_grad * 1_L + U_grad * 1_U (!!!)
//
// Then, transposing the formula for dX above we get:
// B_grad = P L^{-H} U^{-H} X_grad = lu_solve(X_grad, LU_data, LU_pivots, /*adjoint=*/true)
// U_grad = -U^{-H} X_grad X^H
// L_grad = L^{-H} U_grad U^H
// After inserting U_grad and L_grad into (!!!) we get the value for LU_grad.

// Backward of lu_solve (X = (P L U)^{-1} B); see the derivation in the comment above.
// grad            - X_grad;
// X               - the forward output;
// LU_data/pivots  - the packed LU factorization;
// grad_input_mask - {B requires grad, LU_data requires grad}.
// Returns (B_grad, LU_data_grad); an entry is undefined when it was not requested,
// and both are undefined when `grad` is undefined.
std::tuple<Tensor, Tensor> lu_solve_backward(
  const Tensor& grad,
  const Tensor& X,
  const Tensor& LU_data,
  const Tensor& LU_pivots,
  const std::array<bool, 2>& grad_input_mask) {
  const bool B_requires_grad = grad_input_mask[0];
  const bool LU_data_requires_grad = grad_input_mask[1];
  if (!grad.defined() || (!B_requires_grad && !LU_data_requires_grad)) {
    return std::make_tuple(Tensor{}, Tensor{});
  }

  // TODO If just B requires grad, the following formula is better:
  //const auto trans = grad.is_complex() ? TransposeType::ConjTranspose : TransposeType::Transpose;
  //const Tensor B_grad = at::_lu_solve_trans(grad, LU_data, LU_pivots, trans);
  //return std::make_pair(B_grad, Tensor{});
  //
  // We'll be able to use it once we have migrated lu_solve to linalg and it has an `adjoint` flag.
  // This formula avoids the instantiation of P explicitly and may have better numerical properties

  const Tensor X_H = X.mH();
  // The permutation P is only needed for B_grad, so pivots are unpacked only in that case.
  Tensor P, L, U;
  if (B_requires_grad) {
      std::tie(P, L, U) = at::lu_unpack(LU_data, LU_pivots);
  } else {
    std::tie(std::ignore, L, U) = at::lu_unpack(LU_data, LU_pivots,
                                                /*unpack_data=*/true,
                                                /*unpack_pivots=*/false);
  }
  const Tensor U_H = U.mH();
  const Tensor L_H = L.mH();

  // In both branches U_grad and L_grad omit the leading minus sign of the derivation;
  // it is applied once when LU_data_grad is assembled.
  if (B_requires_grad) {
      // Y = U^{-H}X_grad
      const Tensor Y = at::linalg_solve_triangular(U_H, grad, /*upper=*/false);
      // Z = L^{-H}U^{-H}X_grad
      const Tensor Z = at::linalg_solve_triangular(L_H, Y,
                                                   /*upper=*/true,
                                                   /*left=*/true,
                                                   /*unitriangular=*/true);
      const Tensor B_grad = P.matmul(Z);
      Tensor LU_data_grad;
      if (LU_data_requires_grad) {
        // Reuse Y and Z: U_grad = Y X^H, L_grad = Z X^H U^H.
        const Tensor U_grad = Y.matmul(X_H);
        const Tensor L_grad = Z.matmul(X_H).matmul(U_H);
        LU_data_grad = -(L_grad.tril(-1) + U_grad.triu());
      }
      return std::make_pair(B_grad, LU_data_grad);
  } else {
    // The case where neither input needs a gradient was handled at the start, so here
    // only LU_data requires grad.

    // U^{-H}X_grad X^H
    const Tensor U_grad = at::linalg_solve_triangular(U_H, grad.matmul(X_H), /*upper=*/false);
    // L^{-H}U^{-H}X_grad X^H U^H
    const Tensor L_grad = at::linalg_solve_triangular(L_H, U_grad.matmul(U_H),
                                                      /*upper=*/true,
                                                      /*left=*/true,
                                                      /*unitriangular=*/true);

    // LU_data_grad = L_grad * 1_L + U_grad * 1_U
    const Tensor LU_data_grad = -(L_grad.tril(-1) + U_grad.triu());
    return std::make_pair(Tensor{}, LU_data_grad);
  }
}

// Forward-mode derivative (JVP) of lu_solve, X = (P L U)^{-1} B, where LU_data packs
// L (strictly-lower part, unit diagonal implied) and U (upper part).
// See the derivation in the comment above lu_solve_backward:
//   dX = -U^{-1} dU X - U^{-1} L^{-1} dL U X + U^{-1} L^{-1} P^T dB.
// dLU_data and dB are the tangents of LU_data and B.
Tensor lu_solve_jvp(
  const Tensor& X,
  const Tensor& LU_data,
  const Tensor& dLU_data,
  const Tensor& dB,
  const Tensor& LU_pivots
) {
  Tensor L, U, dL, dU;
  std::tie(std::ignore, L, U) = at::lu_unpack(LU_data, LU_pivots, /*unpack_data=*/true, /*unpack_pivots=*/false);
  // The diagonal of L is fixed to 1, so its tangent is the strictly-lower part of
  // dLU_data; U's tangent is the upper part including the diagonal.
  dL = dLU_data.tril(-1);
  dU = dLU_data.triu();
  // Tangent of the product L U.
  auto dA = dL.matmul(U) + L.matmul(dU);
  // NOTE(review): generic_solve_jvp is defined elsewhere. The callback returns
  // lu_solve(dB, A, LU_pivots) - lu_solve(dA_contrib, A, identity_pivots); presumably
  // dA_contrib is dA X. Confirm against generic_solve_jvp's definition.
  return generic_solve_jvp(
    [&](const Tensor& A, const Tensor& dB, const Tensor& dA_contrib) {
      // We exploit the structure in the computation of A^{-1} dA, where the permutation matrix P such that A = P L U
      // cancels itself. Because of that we use lu_solve with the identity permutation input.
      // The identity permutation pivots are 1-based because of the Fortran-like LAPACK interfaces.
      // More details on the permutation matrix canceling note:
      // as part of forward AD we need to compute A^{-1} dA.
      // Since A = P L U and P is locally constant for full-rank matrices, we get
      // dA = P d(L U), A^{-1} = (L U)^{-1} P^T, so
      // A^{-1} dA = (L U)^{-1} d(L U), which is lu_solve with
      // the pivots set to the identity permutation
      auto identity_pivots = at::arange(1, LU_data.size(-1) + 1, LU_pivots.options()).expand(LU_pivots.sizes());
      return at::lu_solve(dB, A, LU_pivots) - at::lu_solve(dA_contrib, A, identity_pivots);
    },
    X, /*A=*/LU_data, dA, dB
  );
}

// Backward of lu_unpack w.r.t. the packed LU_data.
// grads = (P_grad, L_grad, U_grad); grads[0] is not read. L occupies the leading
// min(m, n) columns of LU_data below the diagonal (its unit diagonal is implicit),
// and U occupies the leading min(m, n) rows on and above the diagonal.
// Undefined L_grad / U_grad contribute nothing.
Tensor lu_unpack_backward(
  const variable_list& grads,
  const Tensor& LU_data,
  bool unpack_data
) {
  TORCH_CHECK(unpack_data, "lu_unpack_backward: cannot compute gradients unless unpack_data=True");

  const auto& L_grad = grads[1];
  const auto& U_grad = grads[2];

  const auto m = LU_data.size(-2);
  const auto n = LU_data.size(-1);
  const auto k = std::min(m, n);

  auto LU_grad = at::zeros(LU_data.sizes(), LU_data.options());
  if (L_grad.defined()) {
    // Strictly-lower part only: the diagonal of L is not stored in LU_data.
    LU_grad.narrow(-1, 0, k).add_(L_grad.tril(-1));
  }
  if (U_grad.defined()) {
    LU_grad.narrow(-2, 0, k).add_(U_grad.triu());
  }
  return LU_grad;
}

Tensor cat_jvp(at::TensorList tensors, int64_t dim) {
  Tensor out_fw_grad;

  auto any_defined = false;
  for (const auto& t: tensors) {
    any_defined |= isFwGradDefined(t);
  }

  if (any_defined) {
    std::vector<Tensor> fw_grads;

    for (auto& t: tensors) {
      fw_grads.push_back(isFwGradDefined(t)? t._fw_grad(/*level*/ 0): at::zeros_like(t));
    }

    out_fw_grad = at::cat(fw_grads, dim);
  }

  return out_fw_grad;
}

Tensor stack_jvp(at::TensorList tensors, int64_t dim) {
  // Basically copy of cat_jvp above
  // TOD0: consolidate with the logic of cat_jvp
  Tensor out_fw_grad;

  auto any_defined = false;
  for (const auto& t: tensors) {
    any_defined |= isFwGradDefined(t);
  }

  if (any_defined) {
    std::vector<Tensor> fw_grads;

    for (auto& t: tensors) {
      fw_grads.push_back(isFwGradDefined(t)? t._fw_grad(/*level*/ 0): at::zeros_like(t));
    }
    out_fw_grad = at::stack(fw_grads, dim);
  }
  return out_fw_grad;
}

// Forward-mode derivative of cumprod along `dim`.
// self_t: tangent of the input; self_p: the input; result: cumprod(self_p, dim).
Tensor cumprod_jvp(Tensor self_t, Tensor self_p, Tensor result, int dim) {
  // Valid wherever no zero has been met yet:
  //   d cumprod(x)_j = cumprod(x)_j * sum_{i <= j} t_i / x_i.
  // At and after a zero of self_p this is non-finite and is replaced below.
  Tensor gradient = (self_t / self_p).cumsum(dim) * result;

  if (self_p.dim() == 0) {
    // cumprod of a 0-dim tensor is the tensor itself, so the tangent is self_t;
    // only the self_p == 0 entry needs patching after the division above.
    gradient.masked_fill_(self_p.eq(0), self_t);
    return gradient;
  }

  // Worked example along `dim`: input (a, 0, b, 0, c) with tangents (t0, t1, t2, t3, t4).
  // cumprod gives (a, 0, 0, 0, 0) and the wanted tangent is (t0, a*t1, a*b*t1, 0, 0).
  // at::where is used (rather than arithmetic masking) so the non-finite values of the
  // generic formula are dropped instead of propagated.

  // Zeros of the input: (0, 1, 0, 1, 0)
  auto is_zero = self_p.eq(0);
  // First zero along dim: (0, 1, 0, 0, 0)
  auto is_first_zero = is_zero.logical_and(is_zero.cumsum(dim).eq(1));

  // Tangent from the first zero on: cumprod((a, t1, b, 0, c)) = (X, a*t1, a*b*t1, 0, 0)
  auto grad_from_first_zero = at::where(is_first_zero, self_t, self_p).cumprod(dim);

  // Positions at or after the first zero: (0, 1, 1, 1, 1)
  auto at_or_after_first_zero = is_first_zero.cumsum(dim);

  return at::where(at_or_after_first_zero.to(ScalarType::Bool), grad_from_first_zero, gradient);
}

// Helper for {batch,layer,group}_norms below
// Computes the jvp for `1 / input.std(dims, keepdim)`:
//   invstd_t = -invstd_p^3 * sum_dims((t - mean(t)) * (p - mean_p)) / numel
// Tensor-subclass-like inputs and ZeroTensor tangents take the out-of-place path;
// everything else is updated in place to avoid temporaries.
static Tensor _invstd_jvp(
    const Tensor& input_p, const Tensor& input_t,
    const Tensor& mean_p, const Tensor& invstd_p,
    IntArrayRef dims, int64_t numel, bool keepdim) {
  const bool out_of_place =
      areAnyTensorSubclassLike({input_t, input_p, mean_p, invstd_p}) || input_t._is_zerotensor();
  Tensor invstd_t;
  if (!out_of_place) {
    invstd_t = input_t - input_t.mean(dims, true);
    invstd_t *= input_p - mean_p;
    invstd_t *= -invstd_p.pow(3);
  } else {
    invstd_t = -invstd_p.pow(3) * (input_t - input_t.mean(dims, true)) * (input_p - mean_p);
  }
  invstd_t = invstd_t.sum(dims, keepdim);
  invstd_t /= numel;
  return invstd_t;
}

// Helper for {batch,layer,group}_norms below only
// Computes the jvp for `(input - input.mean(dims)) * input.invstd(dims)`:
//   result_t = (t - mean(t)) * invstd_p + (p - mean_p) * invstd_t
// where invstd_t comes from _invstd_jvp with keepdim = true.
static Tensor _norm_jvp(
    const Tensor& input_p, const Tensor& input_t,
    const Tensor& mean_p, const Tensor& invstd_p,
    IntArrayRef dims, int64_t numel) {
  const Tensor invstd_t = _invstd_jvp(input_p, input_t, mean_p, invstd_p, dims, numel, /*keepdim=*/true);
  const bool out_of_place =
      areAnyTensorSubclassLike({input_t, input_p, mean_p, invstd_p}) || input_t._is_zerotensor();
  if (out_of_place) {
    return (input_t - input_t.mean(dims, true)) * invstd_p + (input_p - mean_p) * invstd_t;
  }
  Tensor centered_t = input_t - input_t.mean(dims, true);
  centered_t *= invstd_p;
  Tensor centered_p = input_p - mean_p;
  centered_p *= invstd_t;
  centered_t += centered_p;
  return centered_t;
}

// Helper for {batch,layer,group}_norms below only
// Computes the jvp for `input * weight + bias` where weight and bias may be undefined:
//   result_t = input_t * weight_p + input_p * weight_t + bias_t
// (terms involving an undefined weight_p / bias_t are skipped).
// Possibly modifies the input inplace: `input_t` is updated in place or reassigned, and
// is what gets returned. On the in-place path the tensor held by `input_p` is also
// multiplied in place by weight_t, so callers should pass a freshly computed input_p.
static Tensor _affine_jvp(
    const c10::optional<Tensor>& input_p, Tensor& input_t,
    const Tensor& weight_p, const Tensor& weight_t,
    const Tensor& bias_t) {
  // We allow input_p to be optional because if weight_p isn't defined,
  // it may be possible to avoid computing input_p
  TORCH_INTERNAL_ASSERT(input_p.has_value() == weight_p.defined());
  if (weight_p.defined()) {
    // Tensor-subclass-like inputs and ZeroTensor tangents use out-of-place arithmetic.
    if (areAnyTensorSubclassLike({input_p.value(), input_t, weight_p, weight_t}) || input_t._is_zerotensor() || weight_t._is_zerotensor()) {
      input_t = input_t * weight_p + input_p.value() * weight_t;
    } else {
      input_t *= weight_p;
      // `temp` is a handle to the same tensor as input_p, so this mutates it in place.
      auto temp = input_p.value();
      temp *= weight_t;
      input_t += temp;
    }
  }
  if (bias_t.defined()) {
    if (areAnyTensorSubclassLike({input_t, bias_t}) || input_t._is_zerotensor()) {
      input_t = input_t + bias_t;
    } else {
      input_t += bias_t;
    }
  }
  return input_t;
}

// Forward-mode derivative (JVP) of batch_norm.
// Statistics are reduced over every dim except dim 1 (channels).
// - train = true: uses the saved batch statistics (saved_mean, saved_invstd), whose own
//   tangents are accounted for by _norm_jvp.
// - train = false: the running statistics are constants, so the normalized tangent is
//   just input_t * invstd; running_mean and running_var must both have values.
// The affine part is then applied by _affine_jvp. bias_p is not read.
Tensor batch_norm_jvp(
    const Tensor& input_p, const Tensor& input_t,
    const Tensor& weight_p, const Tensor& weight_t,
    const Tensor& bias_p, const Tensor& bias_t,
    const c10::optional<Tensor>& running_mean,
    const c10::optional<Tensor>& running_var,
    const Tensor& saved_mean, const Tensor& saved_invstd,
    bool train,
    double eps) {
  auto dims = std::vector<int64_t>{};
  // view_size: input shape with every reduced dim set to 1, used to broadcast the
  // per-channel statistics and affine parameters against the input.
  auto view_size = input_t.sizes().vec();
  // numel: number of elements reduced per channel.
  int64_t numel = 1;
  for (const auto dim : c10::irange(view_size.size())) {
    if (dim != 1) {
      numel *= input_t.size(dim);
      view_size[dim] = 1;
      dims.push_back(dim);
    }
  }
  Tensor mean_p;
  Tensor invstd_p;
  Tensor result_t;
  if (train) {
    mean_p = saved_mean.view(view_size);
    invstd_p = saved_invstd.view(view_size);
    result_t = _norm_jvp(input_p, input_t, mean_p, invstd_p, dims, numel);
  } else {
    TORCH_INTERNAL_ASSERT(
        running_mean.has_value() && running_var.has_value(),
        "Expect running_mean and running_var to have value when train=false");
    mean_p = running_mean.value().view(view_size);
    invstd_p = (1 / at::sqrt(running_var.value() + at::Scalar(eps))).view(view_size);
    result_t = input_t * invstd_p;
  }

  // The normalized primal is only needed when there is a weight to scale.
  c10::optional<Tensor> result_p = weight_p.defined()
    ? c10::optional<Tensor>((input_p - mean_p) * invstd_p) : c10::nullopt;
  return _affine_jvp(
      result_p, result_t,
      weight_p.defined() ? weight_p.view(view_size) : weight_p,
      weight_t.defined() ? weight_t.view(view_size) : weight_t,
      bias_t.defined() ? bias_t.view(view_size) : bias_t);
}

// Variant of batch_norm_jvp for callers that saved the variance instead of the
// inverse standard deviation: converts saved_var to 1 / sqrt(saved_var + eps)
// and delegates to batch_norm_jvp.
Tensor batch_norm_jvp_saved_var(
    const Tensor& input_p, const Tensor& input_t,
    const Tensor& weight_p, const Tensor& weight_t,
    const Tensor& bias_p, const Tensor& bias_t,
    const c10::optional<Tensor>& running_mean,
    const c10::optional<Tensor>& running_var,
    const Tensor& saved_mean, const Tensor& saved_var,
    bool train,
    double eps) {
  const Tensor invstd_from_var = 1 / at::sqrt(saved_var + at::Scalar(eps));
  return batch_norm_jvp(
      input_p, input_t,
      weight_p, weight_t,
      bias_p, bias_t,
      running_mean, running_var,
      saved_mean, invstd_from_var,
      train, eps);
}

// Forward-mode derivative (JVP) of layer_norm.
// Statistics are reduced over the trailing normalized_shape.size() dims; weight and
// bias are broadcast over the leading dims. bias_p is not read.
Tensor layer_norm_jvp(
    const Tensor& input_p, const Tensor& input_t,
    const Tensor& weight_p, const Tensor& weight_t,
    const Tensor& bias_p, const Tensor& bias_t,
    const Tensor& saved_mean, const Tensor& saved_invstd,
    IntArrayRef normalized_shape) {
  auto dims = std::vector<int64_t>{};
  // view_size: input shape with the normalized (trailing) dims set to 1, i.e. the
  // broadcastable shape of the saved statistics.
  auto view_size = input_t.sizes().vec();
  // view_size_affine: input shape with the leading dims set to 1, i.e. the
  // broadcastable shape of weight and bias.
  auto view_size_affine = input_t.sizes().vec();

  // numel: number of elements reduced per row.
  int64_t numel = 1;
  for (const auto i : c10::irange(view_size.size())) {
    if (i < view_size.size() - normalized_shape.size()) {
      view_size_affine[i] = 1;
    } else {
      numel *= input_t.size(i);
      view_size[i] = 1;
      dims.push_back(i);
    }
  }
  auto mean_p = saved_mean.view(view_size);
  auto invstd_p = saved_invstd.view(view_size);
  auto result_t = _norm_jvp(input_p, input_t, mean_p, invstd_p, dims, numel);

  // The normalized primal is only needed when there is a weight to scale.
  c10::optional<Tensor> result_p = weight_p.defined()
    ? c10::optional<Tensor>((input_p - mean_p) * invstd_p) : c10::nullopt;
  return _affine_jvp(
      result_p, result_t,
      weight_p.defined() ? weight_p.view(view_size_affine) : weight_p,
      weight_t.defined() ? weight_t.view(view_size_affine) : weight_t,
      bias_t.defined() ? bias_t.view(view_size_affine) : bias_t);
}

// Forward-mode derivative (JVP) of group_norm.
// The input is viewed as (1, N * groups, -1) so that normalization becomes a
// training-mode batch_norm_jvp reducing over dims {0, 2}, with one "channel" per
// (sample, group). The per-channel affine transform is then applied with
// weight/bias viewed as (1, C, 1, ..., 1). bias_p is not read.
Tensor group_norm_jvp(
    const Tensor& input_p, const Tensor& input_t,
    const Tensor& weight_p, const Tensor& weight_t,
    const Tensor& bias_p, const Tensor& bias_t,
    const Tensor& saved_mean, const Tensor& saved_invstd,
    int64_t groups) {
  auto input_shape = input_p.sizes();
  int64_t N = input_p.size(0);
  int64_t C = input_p.size(1);

  // `N ? -1 : 1`: with N == 0 the tensor is empty and -1 could not be inferred.
  auto input_t_reshaped = input_t.view({1, N * groups, N ? -1 : 1});
  auto input_p_reshaped = input_p.view({1, N * groups, N ? -1 : 1});

  auto result_t = batch_norm_jvp(
      input_p_reshaped, input_t_reshaped,
      /*weight_p=*/{}, /*weight_t=*/{},
      /*bias_p=*/{}, /*bias_t=*/{},
      /*running_mean=*/{}, /*running_var=*/{},
      saved_mean, saved_invstd, /*train=*/true, /*eps=*/0).view(input_shape);

  // The normalized primal is only needed when there is a weight to scale.
  c10::optional<Tensor> result_p = c10::nullopt;
  if (weight_p.defined()) {
    std::vector<int64_t> view_size(input_t_reshaped.dim(), 1);
    view_size[1] = input_t_reshaped.size(1);
    result_p = ((input_p_reshaped - saved_mean.view(view_size)) * saved_invstd.view(view_size)).view(input_shape);
  }
  std::vector<int64_t> affine_param_shape(input_p.dim(), 1);
  affine_param_shape[1] = C;

  return _affine_jvp(
      result_p, result_t,
      weight_p.defined() ? weight_p.view(affine_param_shape) : weight_p,
      weight_t.defined() ? weight_t.view(affine_param_shape) : weight_t,
      bias_t.defined() ? bias_t.view(affine_param_shape) : bias_t);
}

// Forward-mode derivative of the per-(sample, group) mean computed by group_norm.
// The mean is linear, so its tangent is the mean of input_t over each group.
// NOTE(review): assumes input_t is laid out as (N, C, *) with C divisible by groups.
Tensor group_norm_mean_jvp(
    const Tensor& input_t, const Tensor& mean_p, int64_t groups) {
  const int64_t N = input_t.size(0);
  // One row per (sample, group); with N == 0 the tensor is empty and -1 could not
  // be inferred, hence `N ? -1 : 1`.
  std::array<int64_t, 3> view_shape = {1, N * groups, N ? -1 : 1};
  auto input_t_reshaped = input_t.view(view_shape);
  return input_t_reshaped.mean({2}, false).view_as(mean_p);
}

// Forward-mode derivative of the per-(sample, group) inverse standard deviation
// computed by group_norm. The input is viewed as (1, N * groups, -1) and the reduction
// runs over the last dim via _invstd_jvp.
Tensor group_norm_invstd_jvp(
    const Tensor& input_p, const Tensor& input_t,
    const Tensor& mean_p, const Tensor& invstd_p,
    int64_t groups) {
  const int64_t N = input_p.size(0);

  // One row per (sample, group); with N == 0 the tensor is empty and -1 could not
  // be inferred, hence `N ? -1 : 1`.
  std::vector<int64_t> view_shape = {1, N * groups, N ? -1 : 1};

  auto input_t_reshaped = input_t.view(view_shape);
  auto input_p_reshaped = input_p.view(view_shape);

  // _invstd_jvp takes (primal, tangent, ...). Passing them in that order keeps its
  // ZeroTensor fast-path check on the actual tangent, and keeps the formula exact
  // rather than relying on mean_p being the exact mean of input_p.
  return _invstd_jvp(
      input_p_reshaped, input_t_reshaped, mean_p.view(view_shape), invstd_p.view(view_shape),
      /*dims=*/{2}, /*numel=*/input_t_reshaped.size(2), /*keepdim=*/false).view_as(invstd_p);
}

// Gathers `input` along `dim` at `indices`, where `indices` comes from a reduction
// that may have dropped `dim` (keepdim = false). at::gather requires indices with
// the same number of dims as input, so the dim is re-inserted for the gather and
// removed again from the result.
Tensor gather_with_keepdimed_indices(const Tensor& input, int64_t dim, const Tensor& indices, bool keepdim) {
  if (keepdim) {
    return at::gather(input, dim, indices);
  }
  return at::gather(input, dim, indices.unsqueeze(dim)).squeeze(dim);
}

// Let X in \C^{m \times n}, then its pivoted LU decomposition is
// X = P L U, where P is a permutation matrix.
//
// Useful notation:
// Let o denote the elementwise, or Hadamard, product.
// k := min(m, n)
// 1 := ones(k, k),
// 1_U = 1.triu();
// 1_L = 1 - 1_U (note the diagonal is zero)
// For a matrix A, A^H := A.mH()
//
// Below we derive the backward algorithm for the case when m <= n.
// The case m > n could be obtained using the same idea.
// Since we assume m <= n, the LU decomposition of X could be written as
// X = (X1 | X2) = P L (U1 | U2) where X1, U1 in \C^{m \times m}, X2, U2 in \C^{m \times (n - m)}
//
// Forward AD:
//
// dX = P dL U + P L dU => [left-multiply P^T]
// (P^T dX1 | P^T dX2) = (dL U1 + L dU1 | dL U2 + L dU2) (*)
// From (*):
// P^T dX1 = dL U1 + L dU1 => [left-multiply by L^{-1}, right-multiply by U1^{-1}]
// L^{-1} P^T dX1 U1^{-1} = L^{-1} dL + dU1 U1^{-1} (**).
// Note, L is lower-triangular, and so is its inverse, hence L^{-1} dL is lower-triangular.
// Also, since the diagonal of L (all ones) is never exposed explicitly (packed representation),
// the diagonal of dL is zero, and hence diag(L^{-1} dL) = 0.
// Assuming that U1 is full-rank, similarly, dU1 U1^{-1} is upper-triangular.
// Combining these observations we conclude:
//
// L^{-1} dL = (L^{-1} P^T dX1 U1^{-1}) o 1_L,
// dU1 U1^{-1} = (L^{-1} P^T dX1 U1^{-1}) o 1_U.
//
// Hence,
// dL = L [(L^{-1} P^T dX1 U1^{-1}) o 1_L],
// dU1 = [(L^{-1} P^T dX1 U1^{-1}) o 1_U] U1.
// As for dU2, from (*) it follows
// P^T dX2 = dL U2 + L dU2 =>
// dU2 = L^{-1} (P^T dX2 - dL U2).
//
// Backward AD:
//
// The following equality comes in very handy:
// Tr(A (B o C)) = Tr((A o B^T) C) (!)
//
// Tr(X_grad^H dX) = Tr(L_grad^H dL) + Tr(U_grad^H dU), then
//
// Tr(L_grad^H dL) = Tr(L_grad^H L [(L^{-1} P^T dX1 U1^{-1}) o 1_L]) = [using (!)]
//                 = Tr((L_grad^H L o 1_L^T) L^{-1} P^T dX1 U1^{-1}) = [using the cyclic property of Tr]
//                 = Tr(U1^{-1} (L_grad^H L o 1_L^T) L^{-1} P^T dX1)
//
// Similarly, using (!) and the cyclic property of the trace operator:
// Tr(U_grad^H dU) = Tr(U1_grad^H dU1) + Tr(U2_grad^H dU2)
//                 = Tr(U1^{-1} (U1 U1_grad^H o 1_U^T) L^{-1} P^T dX1)
//                 + Tr(U2_grad^H L^{-1} P^T dX2)
//                 - Tr(U1^{-1} (U2 U2_grad^H o 1_L^T) L^{-1} P^T dX1)
//
// By combining the matrices to the left of dX1 and dX2 and then applying conjugate transposition,
// we finally arrive at:
//
// X1_grad = P L^{-H} [L^H L_grad o 1_L + U1_grad U1^H o 1_U - U2_grad U2^H o 1_L] U1^{-H},
// X2_grad = P L^{-H} U2_grad
Tensor plu_backward_base(
  const variable_list& grads,
  const Tensor& self,
  const Tensor& P,
  const Tensor& L,
  const Tensor& U) {
  // Reverse-mode gradient of X = P L U with respect to X (`self`), following the
  // derivation in the comment preceding this function.
  // grads[0] is the gradient w.r.t. L and grads[1] the gradient w.r.t. U. Both
  // are narrowed/multiplied directly below, so they are assumed to be defined.
  // `self` is only read for its trailing sizes m x n.
  auto L_grad = grads[0];
  auto U_grad = grads[1];

  auto m = self.size(-2);
  auto n = self.size(-1);
  auto k = std::min(m, n);

  // Leading k x k ("principal") blocks of the factors and of their gradients.
  auto L_principal = L.narrow(-2, 0, k).narrow(-1, 0, k);
  auto L_principal_H = L_principal.mH();
  auto L_grad_principal = L_grad.narrow(-2, 0, k).narrow(-1, 0, k);
  auto U_principal = U.narrow(-2, 0, k).narrow(-1, 0, k);
  auto U_principal_H = U_principal.mH();
  auto U_grad_principal = U_grad.narrow(-2, 0, k).narrow(-1, 0, k);

  // phi = (L1^H L1_grad) o 1_L + (U1_grad U1^H) o 1_U:
  // tril(-1) keeps the strictly lower part (1_L), triu() keeps the upper part
  // including the diagonal (1_U).
  auto phi_L = L_principal_H.matmul(L_grad_principal).tril(-1);
  auto phi_U = U_grad_principal.matmul(U_principal_H).triu();

  auto phi = phi_L + phi_U;

  Tensor self_grad;
  if (m <= n) {
    // Trailing column block U2 (k x (n - k)) and its gradient; empty when m == n.
    auto U_complement = U.narrow(-2, 0, k).narrow(-1, k, n - k);
    auto U_grad_complement = U_grad.narrow(-2, 0, k).narrow(-1, k, n - k);

    // (U2_grad U2^H) o 1_L, the correction term from X1_grad.
    auto phi_complement = U_grad_complement.matmul(U_complement.mH()).tril(-1);

    // recall the result for X1_grad and X2_grad from above.
    // It can be rewritten as
    // (X1_grad | X2_grad) = P L^{-H} psi, where
    // psi = (psi1 | psi2)
    //     = ([L^H L_grad o 1_L + U1_grad U1^H o 1_U - U2_grad U2^H o 1_L] U1^{-H} | U2_grad),
    // so it is filled in parts.

    // solve for psi1 to avoid the inversion of U1^H
    // (right-sided solve psi1 U1^H = phi - phi_complement; U1^H is lower triangular).
    auto psi_principal = at::linalg_solve_triangular(U_principal_H, phi - phi_complement,
                                                     /*upper=*/false,
                                                     /*left=*/false,
                                                     /*unitriangular=*/false);
    auto psi = at::cat({psi_principal, U_grad_complement}, /*dim=*/-1);

    // X_grad = P L^{-H} psi; L^H is unit upper triangular, so solve instead of inverting.
    self_grad = P.matmul(at::linalg_solve_triangular(L_principal_H, psi,
                                                     /*upper=*/true,
                                                     /*left=*/true,
                                                     /*unitriangular=*/true));
  }
  else {
    // variables psi and phi carry the same meaning as in the case (m <= n),
    // albeit they are differently defined.
    // Trailing row block L2 ((m - k) x k) and its gradient.
    auto L_complement = L.narrow(-2, k, m - k).narrow(-1, 0, k);
    auto L_grad_complement = L_grad.narrow(-2, k, m - k).narrow(-1, 0, k);

    // (L2^H L2_grad) o 1_U: triu() keeps the diagonal in this case.
    auto phi_complement = L_complement.mH().matmul(L_grad_complement).triu();


    // psi1 = L1^{-H} (phi - phi_complement), via a left unit-upper-triangular solve.
    auto psi_principal = at::linalg_solve_triangular(L_principal_H, phi - phi_complement,
                                                     /*upper=*/true,
                                                     /*left=*/true,
                                                     /*unitriangular=*/true);
    // psi = (psi1 ; L2_grad), stacked along the row dimension.
    auto psi = at::cat({psi_principal, L_grad_complement}, -2);

    // X_grad = (P psi) U1^{-H}, via a right-sided lower-triangular solve with U1^H.
    self_grad = at::linalg_solve_triangular(U_principal_H, P.matmul(psi),
                                            /*upper=*/false,
                                            /*left=*/false,
                                            /*unitriangular=*/false);
  }

  return self_grad;
}

Tensor lu_factor_ex_backward(
  const Tensor& grad,
  const Tensor& self,
  const Tensor& LU,
  const Tensor& pivs) {
  // Backward of lu_factor_ex: unpack the packed factorization into P, L, U and
  // reuse the shared P L U backward.
  const auto factors = at::lu_unpack(LU, pivs);
  const Tensor& P = std::get<0>(factors);
  const Tensor& L = std::get<1>(factors);
  const Tensor& U = std::get<2>(factors);

  // The packed result stores LU = L + U - I, so the incoming gradient of the
  // packed tensor serves both as L_grad and as U_grad. plu_backward_base masks
  // each one to the triangle it actually belongs to.
  const variable_list factor_grads = {/*L_grad=*/grad, /*U_grad=*/grad};
  return plu_backward_base(factor_grads, self, P, L, U);
}

Tensor lu_factor_ex_jvp(
  const Tensor& dA,
  const Tensor& LU,
  const Tensor& pivs
) {
  // Forward-mode derivative of the packed LU factorization, following the
  // forward AD derivation in the description of plu_backward_base.
  // The packed output stores LU = L + U - I, so its tangent is dL + dU.
  Tensor P, L, U;
  std::tie(P, L, U) = at::lu_unpack(LU, pivs);

  const auto rows = LU.size(-2);
  const auto cols = LU.size(-1);
  const auto k = std::min(rows, cols);

  // P^T dA = dL U + L dU.
  const auto PT_dA = P.mT().matmul(dA);

  // Block structure: for rows <= cols, A = (A1 | A2) with A1 of size rows x rows;
  // for rows > cols, A = (A1^T | A2^T)^T with A1 of size cols x cols.
  // Below are the leading k x k blocks.
  const auto PT_dA_1 = PT_dA.narrow(-2, 0, k).narrow(-1, 0, k);
  const auto L_1 = L.narrow(-2, 0, k).narrow(-1, 0, k);
  const auto U_1 = U.narrow(-2, 0, k).narrow(-1, 0, k);

  // dK = L1^{-1} (P^T dA1) U1^{-1}, computed with two triangular solves.
  const auto L1_inv_PT_dA_1 = at::linalg_solve_triangular(
      L_1, PT_dA_1, /*upper=*/false, /*left=*/true, /*unitriangular=*/true);
  const auto dK = at::linalg_solve_triangular(U_1, L1_inv_PT_dA_1, /*upper=*/true, /*left=*/false);

  // L1^{-1} dL1 is the strictly lower part of dK, and dU1 U1^{-1} its upper part.
  const auto dK_lower = dK.tril(-1);
  const auto dK_upper = dK.triu();

  // dLU1 = dL1 + dU1
  auto dLU_1 = L_1.matmul(dK_lower) + dK_upper.matmul(U_1);

  if (rows == cols) {
    return dLU_1;
  }

  if (rows < cols) {
    // Trailing column block: dLU2 = L1^{-1} P^T dA2 - (L1^{-1} dL1) U2
    const auto PT_dA_2 = PT_dA.narrow(-1, k, cols - k);
    const auto U_2 = U.narrow(-1, k, cols - k);
    const auto dLU_2 = at::linalg_solve_triangular(
                           L_1, PT_dA_2, /*upper=*/false, /*left=*/true, /*unitriangular=*/true)
                       - dK_lower.matmul(U_2);
    return at::cat({dLU_1, dLU_2}, /*dim=*/-1);
  }

  // Trailing row block (rows > cols): dLU2 = P^T dA2 U1^{-1} - L2 (dU1 U1^{-1})
  const auto PT_dA_2 = PT_dA.narrow(-2, k, rows - k);
  const auto L_2 = L.narrow(-2, k, rows - k);
  const auto dLU_2 = at::linalg_solve_triangular(U_1, PT_dA_2, /*upper=*/true, /*left=*/false)
                     - L_2.matmul(dK_upper);
  return at::cat({dLU_1, dLU_2}, /*dim=*/-2);
}

// Emits a warning every time it is called and hands `grad_output` back
// unchanged (identity on the gradient).
Tensor warn_backwards(const Tensor& grad_output) {
  TORCH_WARN("Warn from backward");
  return grad_output;
}

// This function only exists because cuDNN does not support bias gradient computation and it's not easy
// to slice a std::tuple to return only grad_input / grad_weight from convolution_backward. It will
// be removed when the cudnn_convolution and cudnn_convolution_transpose go away.
std::tuple<Tensor, Tensor> _cudnn_convolution_backward(
    const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, at::IntArrayRef padding,
    at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, bool transposed, int64_t groups,
    ::std::array<bool,2> output_mask) {
  // Returns (grad_input, grad_weight); both are undefined when no gradient
  // flows in through grad_output.
  if (grad_output.defined()) {
    // Delegate to the generic convolution backward with the bias gradient
    // disabled, then drop the (unused) third output.
    Tensor grad_input, grad_weight, unused_grad_bias;
    std::tie(grad_input, grad_weight, unused_grad_bias) = at::convolution_backward(
        grad_output, self, weight, c10::nullopt, stride, padding, dilation, transposed,
        output_padding, groups, {output_mask[0], output_mask[1], false});
    return std::make_tuple(grad_input, grad_weight);
  }
  return std::tuple<Tensor, Tensor>();
}

} // namespace details
} // namespace generated
} // namespace autograd
} // namespace torch
